786113f19c41770f5c4ba67dd1ee22e98b08271f
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2011-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
41 #include <sys/utsname.h>
42 #endif
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::list;
50
51 using namespace testing;
52
53 namespace folly {
54 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
55 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
56 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
57
58 constexpr size_t SSLClient::kMaxReadBufferSz;
59 constexpr size_t SSLClient::kMaxReadsPerEvent;
60
61 void getfds(int fds[2]) {
62   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
63     FAIL() << "failed to create socketpair: " << strerror(errno);
64   }
65   for (int idx = 0; idx < 2; ++idx) {
66     int flags = fcntl(fds[idx], F_GETFL, 0);
67     if (flags == -1) {
68       FAIL() << "failed to get flags for socket " << idx << ": "
69              << strerror(errno);
70     }
71     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
72       FAIL() << "failed to put socket " << idx
73              << " in non-blocking mode: " << strerror(errno);
74     }
75   }
76 }
77
78 void getctx(
79     std::shared_ptr<folly::SSLContext> clientCtx,
80     std::shared_ptr<folly::SSLContext> serverCtx) {
81   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
82
83   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
84   serverCtx->loadCertificate(kTestCert);
85   serverCtx->loadPrivateKey(kTestKey);
86 }
87
88 void sslsocketpair(
89     EventBase* eventBase,
90     AsyncSSLSocket::UniquePtr* clientSock,
91     AsyncSSLSocket::UniquePtr* serverSock) {
92   auto clientCtx = std::make_shared<folly::SSLContext>();
93   auto serverCtx = std::make_shared<folly::SSLContext>();
94   int fds[2];
95   getfds(fds);
96   getctx(clientCtx, serverCtx);
97   clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
98   serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
99
100   // (*clientSock)->setSendTimeout(100);
101   // (*serverSock)->setSendTimeout(100);
102 }
103
104 // client protocol filters
105 bool clientProtoFilterPickPony(
106     unsigned char** client,
107     unsigned int* client_len,
108     const unsigned char*,
109     unsigned int) {
110   // the protocol string in length prefixed byte string. the
111   // length byte is not included in the length
112   static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
113   *client = p;
114   *client_len = 7;
115   return true;
116 }
117
118 bool clientProtoFilterPickNone(
119     unsigned char**,
120     unsigned int*,
121     const unsigned char*,
122     unsigned int) {
123   return false;
124 }
125
126 std::string getFileAsBuf(const char* fileName) {
127   std::string buffer;
128   folly::readFile(fileName, buffer);
129   return buffer;
130 }
131
132 std::string getCommonName(X509* cert) {
133   X509_NAME* subject = X509_get_subject_name(cert);
134   std::string cn;
135   cn.resize(ub_common_name);
136   X509_NAME_get_text_by_NID(
137       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
138   return cn;
139 }
140
141 TEST(AsyncSSLSocketTest, ClientCertValidationResultTest) {
142   EventBase ev;
143   int fd = 0;
144
145   AsyncSSLSocket::UniquePtr sock(
146       new AsyncSSLSocket(std::make_shared<SSLContext>(), &ev, fd, false));
147
148   // Initially the cert is not validated, so no result is available.
149   EXPECT_EQ(nullptr, get_pointer(sock->getClientCertValidationResult()));
150
151   sock->setClientCertValidationResult(
152       make_optional(AsyncSSLSocket::CertValidationResult::CERT_VALID));
153
154   EXPECT_EQ(
155       AsyncSSLSocket::CertValidationResult::CERT_VALID,
156       *sock->getClientCertValidationResult());
157 }
158
159 /**
160  * Test connecting to, writing to, reading from, and closing the
161  * connection to the SSL server.
162  */
163 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
164   // Start listening on a local port
165   WriteCallbackBase writeCallback;
166   ReadCallback readCallback(&writeCallback);
167   HandshakeCallback handshakeCallback(&readCallback);
168   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
169   TestSSLServer server(&acceptCallback);
170
171   // Set up SSL context.
172   std::shared_ptr<SSLContext> sslContext(new SSLContext());
173   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
174   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
175   // sslContext->authenticate(true, false);
176
177   // connect
178   auto socket =
179       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
180   socket->open(std::chrono::milliseconds(10000));
181
182   // write()
183   uint8_t buf[128];
184   memset(buf, 'a', sizeof(buf));
185   socket->write(buf, sizeof(buf));
186
187   // read()
188   uint8_t readbuf[128];
189   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
190   EXPECT_EQ(bytesRead, 128);
191   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
192
193   // close()
194   socket->close();
195
196   cerr << "ConnectWriteReadClose test completed" << endl;
197   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
198 }
199
200 /**
201  * Test reading after server close.
202  */
203 TEST(AsyncSSLSocketTest, ReadAfterClose) {
204   // Start listening on a local port
205   WriteCallbackBase writeCallback;
206   ReadEOFCallback readCallback(&writeCallback);
207   HandshakeCallback handshakeCallback(&readCallback);
208   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
209   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
210
211   // Set up SSL context.
212   auto sslContext = std::make_shared<SSLContext>();
213   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
214
215   auto socket =
216       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
217   socket->open();
218
219   // This should trigger an EOF on the client.
220   auto evb = handshakeCallback.getSocket()->getEventBase();
221   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
222   std::array<uint8_t, 128> readbuf;
223   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
224   EXPECT_EQ(0, bytesRead);
225 }
226
227 /**
228  * Test bad renegotiation
229  */
230 #if !defined(OPENSSL_IS_BORINGSSL)
231 TEST(AsyncSSLSocketTest, Renegotiate) {
232   EventBase eventBase;
233   auto clientCtx = std::make_shared<SSLContext>();
234   auto dfServerCtx = std::make_shared<SSLContext>();
235   std::array<int, 2> fds;
236   getfds(fds.data());
237   getctx(clientCtx, dfServerCtx);
238
239   AsyncSSLSocket::UniquePtr clientSock(
240       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
241   AsyncSSLSocket::UniquePtr serverSock(
242       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
243   SSLHandshakeClient client(std::move(clientSock), true, true);
244   RenegotiatingServer server(std::move(serverSock));
245
246   while (!client.handshakeSuccess_ && !client.handshakeError_) {
247     eventBase.loopOnce();
248   }
249
250   ASSERT_TRUE(client.handshakeSuccess_);
251
252   auto sslSock = std::move(client).moveSocket();
253   sslSock->detachEventBase();
254   // This is nasty, however we don't want to add support for
255   // renegotiation in AsyncSSLSocket.
256   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
257
258   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
259
260   std::thread t([&]() { eventBase.loopForever(); });
261
262   // Trigger the renegotiation.
263   std::array<uint8_t, 128> buf;
264   memset(buf.data(), 'a', buf.size());
265   try {
266     socket->write(buf.data(), buf.size());
267   } catch (AsyncSocketException& e) {
268     LOG(INFO) << "client got error " << e.what();
269   }
270   eventBase.terminateLoopSoon();
271   t.join();
272
273   eventBase.loop();
274   ASSERT_TRUE(server.renegotiationError_);
275 }
276 #endif
277
278 /**
279  * Negative test for handshakeError().
280  */
281 TEST(AsyncSSLSocketTest, HandshakeError) {
282   // Start listening on a local port
283   WriteCallbackBase writeCallback;
284   WriteErrorCallback readCallback(&writeCallback);
285   HandshakeCallback handshakeCallback(&readCallback);
286   HandshakeErrorCallback acceptCallback(&handshakeCallback);
287   TestSSLServer server(&acceptCallback);
288
289   // Set up SSL context.
290   std::shared_ptr<SSLContext> sslContext(new SSLContext());
291   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
292
293   // connect
294   auto socket =
295       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
296   // read()
297   bool ex = false;
298   try {
299     socket->open();
300
301     uint8_t readbuf[128];
302     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
303     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
304   } catch (AsyncSocketException&) {
305     ex = true;
306   }
307   EXPECT_TRUE(ex);
308
309   // close()
310   socket->close();
311   cerr << "HandshakeError test completed" << endl;
312 }
313
314 /**
315  * Negative test for readError().
316  */
317 TEST(AsyncSSLSocketTest, ReadError) {
318   // Start listening on a local port
319   WriteCallbackBase writeCallback;
320   ReadErrorCallback readCallback(&writeCallback);
321   HandshakeCallback handshakeCallback(&readCallback);
322   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
323   TestSSLServer server(&acceptCallback);
324
325   // Set up SSL context.
326   std::shared_ptr<SSLContext> sslContext(new SSLContext());
327   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
328
329   // connect
330   auto socket =
331       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
332   socket->open();
333
334   // write something to trigger ssl handshake
335   uint8_t buf[128];
336   memset(buf, 'a', sizeof(buf));
337   socket->write(buf, sizeof(buf));
338
339   socket->close();
340   cerr << "ReadError test completed" << endl;
341 }
342
343 /**
344  * Negative test for writeError().
345  */
346 TEST(AsyncSSLSocketTest, WriteError) {
347   // Start listening on a local port
348   WriteCallbackBase writeCallback;
349   WriteErrorCallback readCallback(&writeCallback);
350   HandshakeCallback handshakeCallback(&readCallback);
351   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
352   TestSSLServer server(&acceptCallback);
353
354   // Set up SSL context.
355   std::shared_ptr<SSLContext> sslContext(new SSLContext());
356   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
357
358   // connect
359   auto socket =
360       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
361   socket->open();
362
363   // write something to trigger ssl handshake
364   uint8_t buf[128];
365   memset(buf, 'a', sizeof(buf));
366   socket->write(buf, sizeof(buf));
367
368   socket->close();
369   cerr << "WriteError test completed" << endl;
370 }
371
372 /**
373  * Test a socket with TCP_NODELAY unset.
374  */
375 TEST(AsyncSSLSocketTest, SocketWithDelay) {
376   // Start listening on a local port
377   WriteCallbackBase writeCallback;
378   ReadCallback readCallback(&writeCallback);
379   HandshakeCallback handshakeCallback(&readCallback);
380   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
381   TestSSLServer server(&acceptCallback);
382
383   // Set up SSL context.
384   std::shared_ptr<SSLContext> sslContext(new SSLContext());
385   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
386
387   // connect
388   auto socket =
389       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
390   socket->open();
391
392   // write()
393   uint8_t buf[128];
394   memset(buf, 'a', sizeof(buf));
395   socket->write(buf, sizeof(buf));
396
397   // read()
398   uint8_t readbuf[128];
399   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
400   EXPECT_EQ(bytesRead, 128);
401   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
402
403   // close()
404   socket->close();
405
406   cerr << "SocketWithDelay test completed" << endl;
407 }
408
409 using NextProtocolTypePair =
410     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
411
412 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
413   // For matching protos
414  public:
415   void SetUp() override {
416     getctx(clientCtx, serverCtx);
417   }
418
419   void connect(bool unset = false) {
420     getfds(fds);
421
422     if (unset) {
423       // unsetting NPN for any of [client, server] is enough to make NPN not
424       // work
425       clientCtx->unsetNextProtocols();
426     }
427
428     AsyncSSLSocket::UniquePtr clientSock(
429         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
430     AsyncSSLSocket::UniquePtr serverSock(
431         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
432     client = std::make_unique<NpnClient>(std::move(clientSock));
433     server = std::make_unique<NpnServer>(std::move(serverSock));
434
435     eventBase.loop();
436   }
437
438   void expectProtocol(const std::string& proto) {
439     expectHandshakeSuccess();
440     EXPECT_NE(client->nextProtoLength, 0);
441     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
442     EXPECT_EQ(
443         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
444         0);
445     string selected((const char*)client->nextProto, client->nextProtoLength);
446     EXPECT_EQ(proto, selected);
447   }
448
449   void expectNoProtocol() {
450     expectHandshakeSuccess();
451     EXPECT_EQ(client->nextProtoLength, 0);
452     EXPECT_EQ(server->nextProtoLength, 0);
453     EXPECT_EQ(client->nextProto, nullptr);
454     EXPECT_EQ(server->nextProto, nullptr);
455   }
456
457   void expectProtocolType() {
458     expectHandshakeSuccess();
459     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
460         GetParam().second == SSLContext::NextProtocolType::ANY) {
461       EXPECT_EQ(client->protocolType, server->protocolType);
462     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
463                GetParam().second == SSLContext::NextProtocolType::ANY) {
464       // Well not much we can say
465     } else {
466       expectProtocolType(GetParam());
467     }
468   }
469
470   void expectProtocolType(NextProtocolTypePair expected) {
471     expectHandshakeSuccess();
472     EXPECT_EQ(client->protocolType, expected.first);
473     EXPECT_EQ(server->protocolType, expected.second);
474   }
475
476   void expectHandshakeSuccess() {
477     EXPECT_FALSE(client->except.hasValue())
478         << "client handshake error: " << client->except->what();
479     EXPECT_FALSE(server->except.hasValue())
480         << "server handshake error: " << server->except->what();
481   }
482
483   void expectHandshakeError() {
484     EXPECT_TRUE(client->except.hasValue())
485         << "Expected client handshake error!";
486     EXPECT_TRUE(server->except.hasValue())
487         << "Expected server handshake error!";
488   }
489
490   EventBase eventBase;
491   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
492   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
493   int fds[2];
494   std::unique_ptr<NpnClient> client;
495   std::unique_ptr<NpnServer> server;
496 };
497
498 class NextProtocolTLSExtTest : public NextProtocolTest {
499   // For extended TLS protos
500 };
501
502 class NextProtocolNPNOnlyTest : public NextProtocolTest {
503   // For mismatching protos
504 };
505
506 class NextProtocolMismatchTest : public NextProtocolTest {
507   // For mismatching protos
508 };
509
510 TEST_P(NextProtocolTest, NpnTestOverlap) {
511   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
512   serverCtx->setAdvertisedNextProtocols(
513       {"foo", "bar", "baz"}, GetParam().second);
514
515   connect();
516
517   expectProtocol("baz");
518   expectProtocolType();
519 }
520
521 TEST_P(NextProtocolTest, NpnTestUnset) {
522   // Identical to above test, except that we want unset NPN before
523   // looping.
524   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
525   serverCtx->setAdvertisedNextProtocols(
526       {"foo", "bar", "baz"}, GetParam().second);
527
528   connect(true /* unset */);
529
530   // if alpn negotiation fails, type will appear as npn
531   expectNoProtocol();
532   EXPECT_EQ(client->protocolType, server->protocolType);
533 }
534
535 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
536   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
537   serverCtx->setAdvertisedNextProtocols(
538       {"foo", "bar", "baz"}, GetParam().second);
539
540   connect();
541
542   expectNoProtocol();
543   expectProtocolType(
544       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
545 }
546
547 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
548 // will fail on 1.0.2 before that.
549 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
550   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
551   serverCtx->setAdvertisedNextProtocols(
552       {"foo", "bar", "baz"}, GetParam().second);
553   connect();
554
555   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
556       GetParam().second == SSLContext::NextProtocolType::ALPN) {
557     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
558     // mismatch should result in a fatal alert, but this is the current behavior
559     // on all OpenSSL versions/variants, and we want to know if it changes.
560     expectNoProtocol();
561   }
562 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
563   else if (
564       GetParam().first == SSLContext::NextProtocolType::ANY &&
565       GetParam().second == SSLContext::NextProtocolType::ANY) {
566 #if FOLLY_OPENSSL_IS_110
567     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
568     // correct behavior per RFC7301
569     expectHandshakeError();
570 #else
571     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
572     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
573     // NPN *and* ALPN
574     expectNoProtocol();
575 #endif
576   }
577 #endif
578   else {
579     expectProtocol("blub");
580     expectProtocolType(
581         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
582   }
583 }
584
585 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
586   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
587   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
588   serverCtx->setAdvertisedNextProtocols(
589       {"foo", "bar", "baz"}, GetParam().second);
590
591   connect();
592
593   expectProtocol("ponies");
594   expectProtocolType();
595 }
596
597 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
598   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
599   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
600   serverCtx->setAdvertisedNextProtocols(
601       {"foo", "bar", "baz"}, GetParam().second);
602
603   connect();
604
605   expectProtocol("blub");
606   expectProtocolType();
607 }
608
609 TEST_P(NextProtocolTest, RandomizedNpnTest) {
610   // Probability that this test will fail is 2^-64, which could be considered
611   // as negligible.
612   const int kTries = 64;
613
614   clientCtx->setAdvertisedNextProtocols(
615       {"foo", "bar", "baz"}, GetParam().first);
616   serverCtx->setRandomizedAdvertisedNextProtocols(
617       {{1, {"foo"}}, {1, {"bar"}}}, GetParam().second);
618
619   std::set<string> selectedProtocols;
620   for (int i = 0; i < kTries; ++i) {
621     connect();
622
623     EXPECT_NE(client->nextProtoLength, 0);
624     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
625     EXPECT_EQ(
626         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
627         0);
628     string selected((const char*)client->nextProto, client->nextProtoLength);
629     selectedProtocols.insert(selected);
630     expectProtocolType();
631   }
632   EXPECT_EQ(selectedProtocols.size(), 2);
633 }
634
635 INSTANTIATE_TEST_CASE_P(
636     AsyncSSLSocketTest,
637     NextProtocolTest,
638     ::testing::Values(
639         NextProtocolTypePair(
640             SSLContext::NextProtocolType::NPN,
641             SSLContext::NextProtocolType::NPN),
642         NextProtocolTypePair(
643             SSLContext::NextProtocolType::NPN,
644             SSLContext::NextProtocolType::ANY),
645         NextProtocolTypePair(
646             SSLContext::NextProtocolType::ANY,
647             SSLContext::NextProtocolType::ANY)));
648
649 #if FOLLY_OPENSSL_HAS_ALPN
650 INSTANTIATE_TEST_CASE_P(
651     AsyncSSLSocketTest,
652     NextProtocolTLSExtTest,
653     ::testing::Values(
654         NextProtocolTypePair(
655             SSLContext::NextProtocolType::ALPN,
656             SSLContext::NextProtocolType::ALPN),
657         NextProtocolTypePair(
658             SSLContext::NextProtocolType::ALPN,
659             SSLContext::NextProtocolType::ANY),
660         NextProtocolTypePair(
661             SSLContext::NextProtocolType::ANY,
662             SSLContext::NextProtocolType::ALPN)));
663 #endif
664
665 INSTANTIATE_TEST_CASE_P(
666     AsyncSSLSocketTest,
667     NextProtocolNPNOnlyTest,
668     ::testing::Values(NextProtocolTypePair(
669         SSLContext::NextProtocolType::NPN,
670         SSLContext::NextProtocolType::NPN)));
671
672 #if FOLLY_OPENSSL_HAS_ALPN
673 INSTANTIATE_TEST_CASE_P(
674     AsyncSSLSocketTest,
675     NextProtocolMismatchTest,
676     ::testing::Values(
677         NextProtocolTypePair(
678             SSLContext::NextProtocolType::NPN,
679             SSLContext::NextProtocolType::ALPN),
680         NextProtocolTypePair(
681             SSLContext::NextProtocolType::ALPN,
682             SSLContext::NextProtocolType::NPN)));
683 #endif
684
685 #ifndef OPENSSL_NO_TLSEXT
686 /**
687  * 1. Client sends TLSEXT_HOSTNAME in client hello.
688  * 2. Server found a match SSL_CTX and use this SSL_CTX to
689  *    continue the SSL handshake.
690  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
691  */
692 TEST(AsyncSSLSocketTest, SNITestMatch) {
693   EventBase eventBase;
694   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
695   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
696   // Use the same SSLContext to continue the handshake after
697   // tlsext_hostname match.
698   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
699   const std::string serverName("xyz.newdev.facebook.com");
700   int fds[2];
701   getfds(fds);
702   getctx(clientCtx, dfServerCtx);
703
704   AsyncSSLSocket::UniquePtr clientSock(
705       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
706   AsyncSSLSocket::UniquePtr serverSock(
707       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
708   SNIClient client(std::move(clientSock));
709   SNIServer server(
710       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
711
712   eventBase.loop();
713
714   EXPECT_TRUE(client.serverNameMatch);
715   EXPECT_TRUE(server.serverNameMatch);
716 }
717
718 /**
719  * 1. Client sends TLSEXT_HOSTNAME in client hello.
720  * 2. Server cannot find a matching SSL_CTX and continue to use
721  *    the current SSL_CTX to do the handshake.
722  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
723  */
724 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
725   EventBase eventBase;
726   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
727   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
728   // Use the same SSLContext to continue the handshake after
729   // tlsext_hostname match.
730   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
731   const std::string clientRequestingServerName("foo.com");
732   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
733
734   int fds[2];
735   getfds(fds);
736   getctx(clientCtx, dfServerCtx);
737
738   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
739       clientCtx, &eventBase, fds[0], clientRequestingServerName));
740   AsyncSSLSocket::UniquePtr serverSock(
741       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
742   SNIClient client(std::move(clientSock));
743   SNIServer server(
744       std::move(serverSock),
745       dfServerCtx,
746       hskServerCtx,
747       serverExpectedServerName);
748
749   eventBase.loop();
750
751   EXPECT_TRUE(!client.serverNameMatch);
752   EXPECT_TRUE(!server.serverNameMatch);
753 }
754 /**
755  * 1. Client sends TLSEXT_HOSTNAME in client hello.
756  * 2. We then change the serverName.
757  * 3. We expect that we get 'false' as the result for serNameMatch.
758  */
759
760 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
761   EventBase eventBase;
762   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
763   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
764   // Use the same SSLContext to continue the handshake after
765   // tlsext_hostname match.
766   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
767   const std::string serverName("xyz.newdev.facebook.com");
768   int fds[2];
769   getfds(fds);
770   getctx(clientCtx, dfServerCtx);
771
772   AsyncSSLSocket::UniquePtr clientSock(
773       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
774   // Change the server name
775   std::string newName("new.com");
776   clientSock->setServerName(newName);
777   AsyncSSLSocket::UniquePtr serverSock(
778       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
779   SNIClient client(std::move(clientSock));
780   SNIServer server(
781       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
782
783   eventBase.loop();
784
785   EXPECT_TRUE(!client.serverNameMatch);
786 }
787
788 /**
789  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
790  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
791  */
792 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
793   EventBase eventBase;
794   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
795   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
796   // Use the same SSLContext to continue the handshake after
797   // tlsext_hostname match.
798   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
799   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
800
801   int fds[2];
802   getfds(fds);
803   getctx(clientCtx, dfServerCtx);
804
805   AsyncSSLSocket::UniquePtr clientSock(
806       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
807   AsyncSSLSocket::UniquePtr serverSock(
808       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
809   SNIClient client(std::move(clientSock));
810   SNIServer server(
811       std::move(serverSock),
812       dfServerCtx,
813       hskServerCtx,
814       serverExpectedServerName);
815
816   eventBase.loop();
817
818   EXPECT_TRUE(!client.serverNameMatch);
819   EXPECT_TRUE(!server.serverNameMatch);
820 }
821
822 #endif
823 /**
824  * Test SSL client socket
825  */
826 TEST(AsyncSSLSocketTest, SSLClientTest) {
827   // Start listening on a local port
828   WriteCallbackBase writeCallback;
829   ReadCallback readCallback(&writeCallback);
830   HandshakeCallback handshakeCallback(&readCallback);
831   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
832   TestSSLServer server(&acceptCallback);
833
834   // Set up SSL client
835   EventBase eventBase;
836   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
837
838   client->connect();
839   EventBaseAborter eba(&eventBase, 3000);
840   eventBase.loop();
841
842   EXPECT_EQ(client->getMiss(), 1);
843   EXPECT_EQ(client->getHit(), 0);
844
845   cerr << "SSLClientTest test completed" << endl;
846 }
847
848 /**
849  * Test SSL client socket session re-use
850  */
851 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
852   // Start listening on a local port
853   WriteCallbackBase writeCallback;
854   ReadCallback readCallback(&writeCallback);
855   HandshakeCallback handshakeCallback(&readCallback);
856   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
857   TestSSLServer server(&acceptCallback);
858
859   // Set up SSL client
860   EventBase eventBase;
861   auto client =
862       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
863
864   client->connect();
865   EventBaseAborter eba(&eventBase, 3000);
866   eventBase.loop();
867
868   EXPECT_EQ(client->getMiss(), 1);
869   EXPECT_EQ(client->getHit(), 9);
870
871   cerr << "SSLClientTestReuse test completed" << endl;
872 }
873
874 /**
875  * Test SSL client socket timeout
876  */
877 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
878   // Start listening on a local port
879   EmptyReadCallback readCallback;
880   HandshakeCallback handshakeCallback(
881       &readCallback, HandshakeCallback::EXPECT_ERROR);
882   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
883   TestSSLServer server(&acceptCallback);
884
885   // Set up SSL client
886   EventBase eventBase;
887   auto client =
888       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
889   client->connect(true /* write before connect completes */);
890   EventBaseAborter eba(&eventBase, 3000);
891   eventBase.loop();
892
893   usleep(100000);
894   // This is checking that the connectError callback precedes any queued
895   // writeError callbacks.  This matches AsyncSocket's behavior
896   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
897   EXPECT_EQ(client->getErrors(), 1);
898   EXPECT_EQ(client->getMiss(), 0);
899   EXPECT_EQ(client->getHit(), 0);
900
901   cerr << "SSLClientTimeoutTest test completed" << endl;
902 }
903
904 // The next 3 tests need an FB-only extension, and will fail without it
905 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
906 /**
907  * Test SSL server async cache
908  */
909 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
910   // Start listening on a local port
911   WriteCallbackBase writeCallback;
912   ReadCallback readCallback(&writeCallback);
913   HandshakeCallback handshakeCallback(&readCallback);
914   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
915   TestSSLAsyncCacheServer server(&acceptCallback);
916
917   // Set up SSL client
918   EventBase eventBase;
919   auto client =
920       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
921
922   client->connect();
923   EventBaseAborter eba(&eventBase, 3000);
924   eventBase.loop();
925
926   EXPECT_EQ(server.getAsyncCallbacks(), 18);
927   EXPECT_EQ(server.getAsyncLookups(), 9);
928   EXPECT_EQ(client->getMiss(), 10);
929   EXPECT_EQ(client->getHit(), 0);
930
931   cerr << "SSLServerAsyncCacheTest test completed" << endl;
932 }
933
934 /**
935  * Test SSL server accept timeout with cache path
936  */
937 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
938   // Start listening on a local port
939   WriteCallbackBase writeCallback;
940   ReadCallback readCallback(&writeCallback);
941   HandshakeCallback handshakeCallback(&readCallback);
942   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
943   TestSSLAsyncCacheServer server(&acceptCallback);
944
945   // Set up SSL client
946   EventBase eventBase;
947   // only do a TCP connect
948   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
949   sock->connect(nullptr, server.getAddress());
950
951   EmptyReadCallback clientReadCallback;
952   clientReadCallback.tcpSocket_ = sock;
953   sock->setReadCB(&clientReadCallback);
954
955   EventBaseAborter eba(&eventBase, 3000);
956   eventBase.loop();
957
958   EXPECT_EQ(readCallback.state, STATE_WAITING);
959
960   cerr << "SSLServerTimeoutTest test completed" << endl;
961 }
962
963 /**
964  * Test SSL server accept timeout with cache path
965  */
966 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
967   // Start listening on a local port
968   WriteCallbackBase writeCallback;
969   ReadCallback readCallback(&writeCallback);
970   HandshakeCallback handshakeCallback(&readCallback);
971   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
972   TestSSLAsyncCacheServer server(&acceptCallback);
973
974   // Set up SSL client
975   EventBase eventBase;
976   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
977
978   client->connect();
979   EventBaseAborter eba(&eventBase, 3000);
980   eventBase.loop();
981
982   EXPECT_EQ(server.getAsyncCallbacks(), 1);
983   EXPECT_EQ(server.getAsyncLookups(), 1);
984   EXPECT_EQ(client->getErrors(), 1);
985   EXPECT_EQ(client->getMiss(), 1);
986   EXPECT_EQ(client->getHit(), 0);
987
988   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
989 }
990
991 /**
992  * Test SSL server accept timeout with cache path
993  */
994 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
995   // Start listening on a local port
996   WriteCallbackBase writeCallback;
997   ReadCallback readCallback(&writeCallback);
998   HandshakeCallback handshakeCallback(
999       &readCallback, HandshakeCallback::EXPECT_ERROR);
1000   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
1001   TestSSLAsyncCacheServer server(&acceptCallback, 500);
1002
1003   // Set up SSL client
1004   EventBase eventBase;
1005   auto client =
1006       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
1007
1008   client->connect();
1009   EventBaseAborter eba(&eventBase, 3000);
1010   eventBase.loop();
1011
1012   server.getEventBase().runInEventBaseThread(
1013       [&handshakeCallback] { handshakeCallback.closeSocket(); });
1014   // give time for the cache lookup to come back and find it closed
1015   handshakeCallback.waitForHandshake();
1016
1017   EXPECT_EQ(server.getAsyncCallbacks(), 1);
1018   EXPECT_EQ(server.getAsyncLookups(), 1);
1019   EXPECT_EQ(client->getErrors(), 1);
1020   EXPECT_EQ(client->getMiss(), 1);
1021   EXPECT_EQ(client->getHit(), 0);
1022
1023   cerr << "SSLServerCacheCloseTest test completed" << endl;
1024 }
1025 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1026
1027 /**
1028  * Verify Client Ciphers obtained using SSL MSG Callback.
1029  */
1030 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1031   EventBase eventBase;
1032   auto clientCtx = std::make_shared<SSLContext>();
1033   auto serverCtx = std::make_shared<SSLContext>();
1034   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1035   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1036   serverCtx->loadPrivateKey(kTestKey);
1037   serverCtx->loadCertificate(kTestCert);
1038   serverCtx->loadTrustedCertificates(kTestCA);
1039   serverCtx->loadClientCAList(kTestCA);
1040
1041   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1042   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1043   clientCtx->loadPrivateKey(kTestKey);
1044   clientCtx->loadCertificate(kTestCert);
1045   clientCtx->loadTrustedCertificates(kTestCA);
1046
1047   int fds[2];
1048   getfds(fds);
1049
1050   AsyncSSLSocket::UniquePtr clientSock(
1051       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1052   AsyncSSLSocket::UniquePtr serverSock(
1053       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1054
1055   SSLHandshakeClient client(std::move(clientSock), true, true);
1056   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1057
1058   eventBase.loop();
1059
1060 #if defined(OPENSSL_IS_BORINGSSL)
1061   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1062 #else
1063   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1064 #endif
1065   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1066   EXPECT_TRUE(client.handshakeVerify_);
1067   EXPECT_TRUE(client.handshakeSuccess_);
1068   EXPECT_TRUE(!client.handshakeError_);
1069   EXPECT_TRUE(server.handshakeVerify_);
1070   EXPECT_TRUE(server.handshakeSuccess_);
1071   EXPECT_TRUE(!server.handshakeError_);
1072 }
1073
1074 /**
1075  * Verify that server is able to get client cert by getPeerCert() API.
1076  */
1077 TEST(AsyncSSLSocketTest, GetClientCertificate) {
1078   EventBase eventBase;
1079   auto clientCtx = std::make_shared<SSLContext>();
1080   auto serverCtx = std::make_shared<SSLContext>();
1081   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1082   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1083   serverCtx->loadPrivateKey(kTestKey);
1084   serverCtx->loadCertificate(kTestCert);
1085   serverCtx->loadTrustedCertificates(kClientTestCA);
1086   serverCtx->loadClientCAList(kClientTestCA);
1087
1088   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1089   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1090   clientCtx->loadPrivateKey(kClientTestKey);
1091   clientCtx->loadCertificate(kClientTestCert);
1092   clientCtx->loadTrustedCertificates(kTestCA);
1093
1094   std::array<int, 2> fds;
1095   getfds(fds.data());
1096
1097   AsyncSSLSocket::UniquePtr clientSock(
1098       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1099   AsyncSSLSocket::UniquePtr serverSock(
1100       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1101
1102   SSLHandshakeClient client(std::move(clientSock), true, true);
1103   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1104
1105   eventBase.loop();
1106
1107   // Handshake should succeed.
1108   EXPECT_TRUE(client.handshakeSuccess_);
1109   EXPECT_TRUE(server.handshakeSuccess_);
1110
1111   // Reclaim the sockets from SSLHandshakeBase.
1112   auto cliSocket = std::move(client).moveSocket();
1113   auto srvSocket = std::move(server).moveSocket();
1114
1115   // Client cert retrieved from server side.
1116   folly::ssl::X509UniquePtr serverPeerCert = srvSocket->getPeerCert();
1117   CHECK(serverPeerCert);
1118
1119   // Client cert retrieved from client side.
1120   const X509* clientSelfCert = cliSocket->getSelfCert();
1121   CHECK(clientSelfCert);
1122
1123   // The two certs should be the same.
1124   EXPECT_EQ(0, X509_cmp(clientSelfCert, serverPeerCert.get()));
1125 }
1126
1127 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1128   EventBase eventBase;
1129   auto ctx = std::make_shared<SSLContext>();
1130
1131   int fds[2];
1132   getfds(fds);
1133
1134   int bufLen = 42;
1135   uint8_t majorVersion = 18;
1136   uint8_t minorVersion = 25;
1137
1138   // Create callback buf
1139   auto buf = IOBuf::create(bufLen);
1140   buf->append(bufLen);
1141   folly::io::RWPrivateCursor cursor(buf.get());
1142   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1143   cursor.write<uint16_t>(0);
1144   cursor.write<uint8_t>(38);
1145   cursor.write<uint8_t>(majorVersion);
1146   cursor.write<uint8_t>(minorVersion);
1147   cursor.skip(32);
1148   cursor.write<uint32_t>(0);
1149
1150   SSL* ssl = ctx->createSSL();
1151   SCOPE_EXIT {
1152     SSL_free(ssl);
1153   };
1154   AsyncSSLSocket::UniquePtr sock(
1155       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1156   sock->enableClientHelloParsing();
1157
1158   // Test client hello parsing in one packet
1159   AsyncSSLSocket::clientHelloParsingCallback(
1160       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1161   buf.reset();
1162
1163   auto parsedClientHello = sock->getClientHelloInfo();
1164   EXPECT_TRUE(parsedClientHello != nullptr);
1165   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1166   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1167 }
1168
1169 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1170   EventBase eventBase;
1171   auto ctx = std::make_shared<SSLContext>();
1172
1173   int fds[2];
1174   getfds(fds);
1175
1176   int bufLen = 42;
1177   uint8_t majorVersion = 18;
1178   uint8_t minorVersion = 25;
1179
1180   // Create callback buf
1181   auto buf = IOBuf::create(bufLen);
1182   buf->append(bufLen);
1183   folly::io::RWPrivateCursor cursor(buf.get());
1184   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1185   cursor.write<uint16_t>(0);
1186   cursor.write<uint8_t>(38);
1187   cursor.write<uint8_t>(majorVersion);
1188   cursor.write<uint8_t>(minorVersion);
1189   cursor.skip(32);
1190   cursor.write<uint32_t>(0);
1191
1192   SSL* ssl = ctx->createSSL();
1193   SCOPE_EXIT {
1194     SSL_free(ssl);
1195   };
1196   AsyncSSLSocket::UniquePtr sock(
1197       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1198   sock->enableClientHelloParsing();
1199
1200   // Test parsing with two packets with first packet size < 3
1201   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1202   AsyncSSLSocket::clientHelloParsingCallback(
1203       0,
1204       0,
1205       SSL3_RT_HANDSHAKE,
1206       bufCopy->data(),
1207       bufCopy->length(),
1208       ssl,
1209       sock.get());
1210   bufCopy.reset();
1211   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1212   AsyncSSLSocket::clientHelloParsingCallback(
1213       0,
1214       0,
1215       SSL3_RT_HANDSHAKE,
1216       bufCopy->data(),
1217       bufCopy->length(),
1218       ssl,
1219       sock.get());
1220   bufCopy.reset();
1221
1222   auto parsedClientHello = sock->getClientHelloInfo();
1223   EXPECT_TRUE(parsedClientHello != nullptr);
1224   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1225   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1226 }
1227
1228 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1229   EventBase eventBase;
1230   auto ctx = std::make_shared<SSLContext>();
1231
1232   int fds[2];
1233   getfds(fds);
1234
1235   int bufLen = 42;
1236   uint8_t majorVersion = 18;
1237   uint8_t minorVersion = 25;
1238
1239   // Create callback buf
1240   auto buf = IOBuf::create(bufLen);
1241   buf->append(bufLen);
1242   folly::io::RWPrivateCursor cursor(buf.get());
1243   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1244   cursor.write<uint16_t>(0);
1245   cursor.write<uint8_t>(38);
1246   cursor.write<uint8_t>(majorVersion);
1247   cursor.write<uint8_t>(minorVersion);
1248   cursor.skip(32);
1249   cursor.write<uint32_t>(0);
1250
1251   SSL* ssl = ctx->createSSL();
1252   SCOPE_EXIT {
1253     SSL_free(ssl);
1254   };
1255   AsyncSSLSocket::UniquePtr sock(
1256       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1257   sock->enableClientHelloParsing();
1258
1259   // Test parsing with multiple small packets
1260   for (uint64_t i = 0; i < buf->length(); i += 3) {
1261     auto bufCopy = folly::IOBuf::copyBuffer(
1262         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1263     AsyncSSLSocket::clientHelloParsingCallback(
1264         0,
1265         0,
1266         SSL3_RT_HANDSHAKE,
1267         bufCopy->data(),
1268         bufCopy->length(),
1269         ssl,
1270         sock.get());
1271     bufCopy.reset();
1272   }
1273
1274   auto parsedClientHello = sock->getClientHelloInfo();
1275   EXPECT_TRUE(parsedClientHello != nullptr);
1276   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1277   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1278 }
1279
1280 /**
1281  * Verify sucessful behavior of SSL certificate validation.
1282  */
1283 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1284   EventBase eventBase;
1285   auto clientCtx = std::make_shared<SSLContext>();
1286   auto dfServerCtx = std::make_shared<SSLContext>();
1287
1288   int fds[2];
1289   getfds(fds);
1290   getctx(clientCtx, dfServerCtx);
1291
1292   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1293   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1294
1295   AsyncSSLSocket::UniquePtr clientSock(
1296       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1297   AsyncSSLSocket::UniquePtr serverSock(
1298       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1299
1300   SSLHandshakeClient client(std::move(clientSock), true, true);
1301   clientCtx->loadTrustedCertificates(kTestCA);
1302
1303   SSLHandshakeServer server(std::move(serverSock), true, true);
1304
1305   eventBase.loop();
1306
1307   EXPECT_TRUE(client.handshakeVerify_);
1308   EXPECT_TRUE(client.handshakeSuccess_);
1309   EXPECT_TRUE(!client.handshakeError_);
1310   EXPECT_LE(0, client.handshakeTime.count());
1311   EXPECT_TRUE(!server.handshakeVerify_);
1312   EXPECT_TRUE(server.handshakeSuccess_);
1313   EXPECT_TRUE(!server.handshakeError_);
1314   EXPECT_LE(0, server.handshakeTime.count());
1315 }
1316
1317 /**
1318  * Verify that the client's verification callback is able to fail SSL
1319  * connection establishment.
1320  */
1321 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1322   EventBase eventBase;
1323   auto clientCtx = std::make_shared<SSLContext>();
1324   auto dfServerCtx = std::make_shared<SSLContext>();
1325
1326   int fds[2];
1327   getfds(fds);
1328   getctx(clientCtx, dfServerCtx);
1329
1330   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1331   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1332
1333   AsyncSSLSocket::UniquePtr clientSock(
1334       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1335   AsyncSSLSocket::UniquePtr serverSock(
1336       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1337
1338   SSLHandshakeClient client(std::move(clientSock), true, false);
1339   clientCtx->loadTrustedCertificates(kTestCA);
1340
1341   SSLHandshakeServer server(std::move(serverSock), true, true);
1342
1343   eventBase.loop();
1344
1345   EXPECT_TRUE(client.handshakeVerify_);
1346   EXPECT_TRUE(!client.handshakeSuccess_);
1347   EXPECT_TRUE(client.handshakeError_);
1348   EXPECT_LE(0, client.handshakeTime.count());
1349   EXPECT_TRUE(!server.handshakeVerify_);
1350   EXPECT_TRUE(!server.handshakeSuccess_);
1351   EXPECT_TRUE(server.handshakeError_);
1352   EXPECT_LE(0, server.handshakeTime.count());
1353 }
1354
1355 /**
1356  * Verify that the options in SSLContext can be overridden in
1357  * sslConnect/Accept.i.e specifying that no validation should be performed
1358  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1359  * the validation callback.
1360  */
1361 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1362   EventBase eventBase;
1363   auto clientCtx = std::make_shared<SSLContext>();
1364   auto dfServerCtx = std::make_shared<SSLContext>();
1365
1366   int fds[2];
1367   getfds(fds);
1368   getctx(clientCtx, dfServerCtx);
1369
1370   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1371   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1372
1373   AsyncSSLSocket::UniquePtr clientSock(
1374       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1375   AsyncSSLSocket::UniquePtr serverSock(
1376       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1377
1378   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1379   clientCtx->loadTrustedCertificates(kTestCA);
1380
1381   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1382
1383   eventBase.loop();
1384
1385   EXPECT_TRUE(!client.handshakeVerify_);
1386   EXPECT_TRUE(client.handshakeSuccess_);
1387   EXPECT_TRUE(!client.handshakeError_);
1388   EXPECT_LE(0, client.handshakeTime.count());
1389   EXPECT_TRUE(!server.handshakeVerify_);
1390   EXPECT_TRUE(server.handshakeSuccess_);
1391   EXPECT_TRUE(!server.handshakeError_);
1392   EXPECT_LE(0, server.handshakeTime.count());
1393 }
1394
1395 /**
1396  * Verify that the options in SSLContext can be overridden in
1397  * sslConnect/Accept. Enable verification even if context says otherwise.
1398  * Test requireClientCert with client cert
1399  */
1400 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1401   EventBase eventBase;
1402   auto clientCtx = std::make_shared<SSLContext>();
1403   auto serverCtx = std::make_shared<SSLContext>();
1404   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1405   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1406   serverCtx->loadPrivateKey(kTestKey);
1407   serverCtx->loadCertificate(kTestCert);
1408   serverCtx->loadTrustedCertificates(kTestCA);
1409   serverCtx->loadClientCAList(kTestCA);
1410
1411   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1412   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1413   clientCtx->loadPrivateKey(kTestKey);
1414   clientCtx->loadCertificate(kTestCert);
1415   clientCtx->loadTrustedCertificates(kTestCA);
1416
1417   int fds[2];
1418   getfds(fds);
1419
1420   AsyncSSLSocket::UniquePtr clientSock(
1421       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1422   AsyncSSLSocket::UniquePtr serverSock(
1423       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1424
1425   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1426   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1427
1428   eventBase.loop();
1429
1430   EXPECT_TRUE(client.handshakeVerify_);
1431   EXPECT_TRUE(client.handshakeSuccess_);
1432   EXPECT_FALSE(client.handshakeError_);
1433   EXPECT_LE(0, client.handshakeTime.count());
1434   EXPECT_TRUE(server.handshakeVerify_);
1435   EXPECT_TRUE(server.handshakeSuccess_);
1436   EXPECT_FALSE(server.handshakeError_);
1437   EXPECT_LE(0, server.handshakeTime.count());
1438 }
1439
1440 /**
1441  * Verify that the client's verification callback is able to override
1442  * the preverification failure and allow a successful connection.
1443  */
1444 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1445   EventBase eventBase;
1446   auto clientCtx = std::make_shared<SSLContext>();
1447   auto dfServerCtx = std::make_shared<SSLContext>();
1448
1449   int fds[2];
1450   getfds(fds);
1451   getctx(clientCtx, dfServerCtx);
1452
1453   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1454   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1455
1456   AsyncSSLSocket::UniquePtr clientSock(
1457       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1458   AsyncSSLSocket::UniquePtr serverSock(
1459       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1460
1461   SSLHandshakeClient client(std::move(clientSock), false, true);
1462   SSLHandshakeServer server(std::move(serverSock), true, true);
1463
1464   eventBase.loop();
1465
1466   EXPECT_TRUE(client.handshakeVerify_);
1467   EXPECT_TRUE(client.handshakeSuccess_);
1468   EXPECT_TRUE(!client.handshakeError_);
1469   EXPECT_LE(0, client.handshakeTime.count());
1470   EXPECT_TRUE(!server.handshakeVerify_);
1471   EXPECT_TRUE(server.handshakeSuccess_);
1472   EXPECT_TRUE(!server.handshakeError_);
1473   EXPECT_LE(0, server.handshakeTime.count());
1474 }
1475
1476 /**
1477  * Verify that specifying that no validation should be performed allows an
1478  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1479  * callback.
1480  */
1481 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1482   EventBase eventBase;
1483   auto clientCtx = std::make_shared<SSLContext>();
1484   auto dfServerCtx = std::make_shared<SSLContext>();
1485
1486   int fds[2];
1487   getfds(fds);
1488   getctx(clientCtx, dfServerCtx);
1489
1490   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1491   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1492
1493   AsyncSSLSocket::UniquePtr clientSock(
1494       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1495   AsyncSSLSocket::UniquePtr serverSock(
1496       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1497
1498   SSLHandshakeClient client(std::move(clientSock), false, false);
1499   SSLHandshakeServer server(std::move(serverSock), false, false);
1500
1501   eventBase.loop();
1502
1503   EXPECT_TRUE(!client.handshakeVerify_);
1504   EXPECT_TRUE(client.handshakeSuccess_);
1505   EXPECT_TRUE(!client.handshakeError_);
1506   EXPECT_LE(0, client.handshakeTime.count());
1507   EXPECT_TRUE(!server.handshakeVerify_);
1508   EXPECT_TRUE(server.handshakeSuccess_);
1509   EXPECT_TRUE(!server.handshakeError_);
1510   EXPECT_LE(0, server.handshakeTime.count());
1511 }
1512
1513 /**
1514  * Test requireClientCert with client cert
1515  */
1516 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1517   EventBase eventBase;
1518   auto clientCtx = std::make_shared<SSLContext>();
1519   auto serverCtx = std::make_shared<SSLContext>();
1520   serverCtx->setVerificationOption(
1521       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1522   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1523   serverCtx->loadPrivateKey(kTestKey);
1524   serverCtx->loadCertificate(kTestCert);
1525   serverCtx->loadTrustedCertificates(kTestCA);
1526   serverCtx->loadClientCAList(kTestCA);
1527
1528   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1529   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1530   clientCtx->loadPrivateKey(kTestKey);
1531   clientCtx->loadCertificate(kTestCert);
1532   clientCtx->loadTrustedCertificates(kTestCA);
1533
1534   int fds[2];
1535   getfds(fds);
1536
1537   AsyncSSLSocket::UniquePtr clientSock(
1538       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1539   AsyncSSLSocket::UniquePtr serverSock(
1540       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1541
1542   SSLHandshakeClient client(std::move(clientSock), true, true);
1543   SSLHandshakeServer server(std::move(serverSock), true, true);
1544
1545   eventBase.loop();
1546
1547   EXPECT_TRUE(client.handshakeVerify_);
1548   EXPECT_TRUE(client.handshakeSuccess_);
1549   EXPECT_FALSE(client.handshakeError_);
1550   EXPECT_LE(0, client.handshakeTime.count());
1551   EXPECT_TRUE(server.handshakeVerify_);
1552   EXPECT_TRUE(server.handshakeSuccess_);
1553   EXPECT_FALSE(server.handshakeError_);
1554   EXPECT_LE(0, server.handshakeTime.count());
1555 }
1556
1557 /**
1558  * Test requireClientCert with no client cert
1559  */
1560 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1561   EventBase eventBase;
1562   auto clientCtx = std::make_shared<SSLContext>();
1563   auto serverCtx = std::make_shared<SSLContext>();
1564   serverCtx->setVerificationOption(
1565       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1566   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1567   serverCtx->loadPrivateKey(kTestKey);
1568   serverCtx->loadCertificate(kTestCert);
1569   serverCtx->loadTrustedCertificates(kTestCA);
1570   serverCtx->loadClientCAList(kTestCA);
1571   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1572   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1573
1574   int fds[2];
1575   getfds(fds);
1576
1577   AsyncSSLSocket::UniquePtr clientSock(
1578       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1579   AsyncSSLSocket::UniquePtr serverSock(
1580       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1581
1582   SSLHandshakeClient client(std::move(clientSock), false, false);
1583   SSLHandshakeServer server(std::move(serverSock), false, false);
1584
1585   eventBase.loop();
1586
1587   EXPECT_FALSE(server.handshakeVerify_);
1588   EXPECT_FALSE(server.handshakeSuccess_);
1589   EXPECT_TRUE(server.handshakeError_);
1590   EXPECT_LE(0, client.handshakeTime.count());
1591   EXPECT_LE(0, server.handshakeTime.count());
1592 }
1593
1594 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1595   auto cert = getFileAsBuf(kTestCert);
1596   auto key = getFileAsBuf(kTestKey);
1597
1598   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1599   BIO_write(certBio.get(), cert.data(), cert.size());
1600   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1601   BIO_write(keyBio.get(), key.data(), key.size());
1602
1603   // Create SSL structs from buffers to get properties
1604   ssl::X509UniquePtr certStruct(
1605       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1606   ssl::EvpPkeyUniquePtr keyStruct(
1607       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1608   certBio = nullptr;
1609   keyBio = nullptr;
1610
1611   auto origCommonName = getCommonName(certStruct.get());
1612   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1613   certStruct = nullptr;
1614   keyStruct = nullptr;
1615
1616   auto ctx = std::make_shared<SSLContext>();
1617   ctx->loadPrivateKeyFromBufferPEM(key);
1618   ctx->loadCertificateFromBufferPEM(cert);
1619   ctx->loadTrustedCertificates(kTestCA);
1620
1621   ssl::SSLUniquePtr ssl(ctx->createSSL());
1622
1623   auto newCert = SSL_get_certificate(ssl.get());
1624   auto newKey = SSL_get_privatekey(ssl.get());
1625
1626   // Get properties from SSL struct
1627   auto newCommonName = getCommonName(newCert);
1628   auto newKeySize = EVP_PKEY_bits(newKey);
1629
1630   // Check that the key and cert have the expected properties
1631   EXPECT_EQ(origCommonName, newCommonName);
1632   EXPECT_EQ(origKeySize, newKeySize);
1633 }
1634
1635 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1636   EventBase eb;
1637
1638   // Set up SSL context.
1639   auto sslContext = std::make_shared<SSLContext>();
1640   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1641
1642   // create SSL socket
1643   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1644
1645   EXPECT_EQ(1500, socket->getMinWriteSize());
1646
1647   socket->setMinWriteSize(0);
1648   EXPECT_EQ(0, socket->getMinWriteSize());
1649   socket->setMinWriteSize(50000);
1650   EXPECT_EQ(50000, socket->getMinWriteSize());
1651 }
1652
1653 class ReadCallbackTerminator : public ReadCallback {
1654  public:
1655   ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
1656       : ReadCallback(wcb), base_(base) {}
1657
1658   // Do not write data back, terminate the loop.
1659   void readDataAvailable(size_t len) noexcept override {
1660     std::cerr << "readDataAvailable, len " << len << std::endl;
1661
1662     currentBuffer.length = len;
1663
1664     buffers.push_back(currentBuffer);
1665     currentBuffer.reset();
1666     state = STATE_SUCCEEDED;
1667
1668     socket_->setReadCB(nullptr);
1669     base_->terminateLoopSoon();
1670   }
1671
1672  private:
1673   EventBase* base_;
1674 };
1675
1676 /**
1677  * Test a full unencrypted codepath
1678  */
1679 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1680   EventBase base;
1681
1682   auto clientCtx = std::make_shared<folly::SSLContext>();
1683   auto serverCtx = std::make_shared<folly::SSLContext>();
1684   int fds[2];
1685   getfds(fds);
1686   getctx(clientCtx, serverCtx);
1687   auto client =
1688       AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
1689   auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
1690
1691   ReadCallbackTerminator readCallback(&base, nullptr);
1692   server->setReadCB(&readCallback);
1693   readCallback.setSocket(server);
1694
1695   uint8_t buf[128];
1696   memset(buf, 'a', sizeof(buf));
1697   client->write(nullptr, buf, sizeof(buf));
1698
1699   // Check that bytes are unencrypted
1700   char c;
1701   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1702   EXPECT_EQ('a', c);
1703
1704   EventBaseAborter eba(&base, 3000);
1705   base.loop();
1706
1707   EXPECT_EQ(1, readCallback.buffers.size());
1708   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1709
1710   server->setReadCB(&readCallback);
1711
1712   // Unencrypted
1713   server->sslAccept(nullptr);
1714   client->sslConn(nullptr);
1715
1716   // Do NOT wait for handshake, writing should be queued and happen after
1717
1718   client->write(nullptr, buf, sizeof(buf));
1719
1720   // Check that bytes are *not* unencrypted
1721   char c2;
1722   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1723   EXPECT_NE('a', c2);
1724
1725   base.loop();
1726
1727   EXPECT_EQ(2, readCallback.buffers.size());
1728   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1729 }
1730
1731 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1732   auto clientCtx = std::make_shared<folly::SSLContext>();
1733   auto serverCtx = std::make_shared<folly::SSLContext>();
1734   getctx(clientCtx, serverCtx);
1735
1736   WriteCallbackBase writeCallback;
1737   ReadCallback readCallback(&writeCallback);
1738   HandshakeCallback handshakeCallback(&readCallback);
1739   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1740   TestSSLServer server(&acceptCallback);
1741
1742   EventBase evb;
1743   std::shared_ptr<AsyncSSLSocket> socket =
1744       AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1745   socket->connect(nullptr, server.getAddress(), 0);
1746
1747   evb.loop();
1748
1749   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1750   socket->sslConn(nullptr);
1751   evb.loop();
1752   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1753
1754   // write()
1755   std::array<uint8_t, 128> buf;
1756   memset(buf.data(), 'a', buf.size());
1757   socket->write(nullptr, buf.data(), buf.size());
1758
1759   socket->close();
1760 }
1761
1762 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1763   // Start listening on a local port
1764   WriteCallbackBase writeCallback;
1765   WriteErrorCallback readCallback(&writeCallback);
1766   HandshakeCallback handshakeCallback(
1767       &readCallback, HandshakeCallback::EXPECT_ERROR);
1768   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1769   TestSSLServer server(&acceptCallback);
1770
1771   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1772   socket->open();
1773   uint8_t buf[3] = {0x16, 0x03, 0x01};
1774   socket->write(buf, sizeof(buf));
1775   socket->closeWithReset();
1776
1777   handshakeCallback.waitForHandshake();
1778   EXPECT_NE(
1779       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1780   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1781 }
1782
1783 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1784   // Start listening on a local port
1785   WriteCallbackBase writeCallback;
1786   WriteErrorCallback readCallback(&writeCallback);
1787   HandshakeCallback handshakeCallback(
1788       &readCallback, HandshakeCallback::EXPECT_ERROR);
1789   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1790   TestSSLServer server(&acceptCallback);
1791
1792   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1793   socket->open();
1794   uint8_t buf[3] = {0x16, 0x03, 0x01};
1795   socket->write(buf, sizeof(buf));
1796   socket->close();
1797
1798   handshakeCallback.waitForHandshake();
1799 #if FOLLY_OPENSSL_IS_110
1800   EXPECT_NE(
1801       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1802 #else
1803   EXPECT_NE(
1804       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1805 #endif
1806 }
1807
1808 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1809   // Start listening on a local port
1810   WriteCallbackBase writeCallback;
1811   WriteErrorCallback readCallback(&writeCallback);
1812   HandshakeCallback handshakeCallback(
1813       &readCallback, HandshakeCallback::EXPECT_ERROR);
1814   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1815   TestSSLServer server(&acceptCallback);
1816
1817   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1818   socket->open();
1819   uint8_t buf[256] = {0x16, 0x03};
1820   memset(buf + 2, 'a', sizeof(buf) - 2);
1821   socket->write(buf, sizeof(buf));
1822   socket->close();
1823
1824   handshakeCallback.waitForHandshake();
1825   EXPECT_NE(
1826       handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
1827 #if defined(OPENSSL_IS_BORINGSSL)
1828   EXPECT_NE(
1829       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1830       std::string::npos);
1831 #elif FOLLY_OPENSSL_IS_110
1832   EXPECT_NE(
1833       handshakeCallback.errorString_.find("packet length too long"),
1834       std::string::npos);
1835 #else
1836   EXPECT_NE(
1837       handshakeCallback.errorString_.find("unknown protocol"),
1838       std::string::npos);
1839 #endif
1840 }
1841
1842 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1843   using folly::ssl::OpenSSLUtils;
1844   EXPECT_EQ(
1845       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1846   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1847   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1848   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1849   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1850 }
1851
1852 #if FOLLY_ALLOW_TFO
1853
1854 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1855  public:
1856   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1857
1858   explicit MockAsyncTFOSSLSocket(
1859       std::shared_ptr<folly::SSLContext> sslCtx,
1860       EventBase* evb)
1861       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1862
1863   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1864 };
1865
1866 /**
1867  * Test connecting to, writing to, reading from, and closing the
1868  * connection to the SSL server with TFO.
1869  */
1870 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1871   // Start listening on a local port
1872   WriteCallbackBase writeCallback;
1873   ReadCallback readCallback(&writeCallback);
1874   HandshakeCallback handshakeCallback(&readCallback);
1875   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1876   TestSSLServer server(&acceptCallback, true);
1877
1878   // Set up SSL context.
1879   auto sslContext = std::make_shared<SSLContext>();
1880
1881   // connect
1882   auto socket =
1883       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1884   socket->enableTFO();
1885   socket->open();
1886
1887   // write()
1888   std::array<uint8_t, 128> buf;
1889   memset(buf.data(), 'a', buf.size());
1890   socket->write(buf.data(), buf.size());
1891
1892   // read()
1893   std::array<uint8_t, 128> readbuf;
1894   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1895   EXPECT_EQ(bytesRead, 128);
1896   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1897
1898   // close()
1899   socket->close();
1900 }
1901
1902 /**
1903  * Test connecting to, writing to, reading from, and closing the
1904  * connection to the SSL server with TFO.
1905  */
1906 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1907   // Start listening on a local port
1908   WriteCallbackBase writeCallback;
1909   ReadCallback readCallback(&writeCallback);
1910   HandshakeCallback handshakeCallback(&readCallback);
1911   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1912   TestSSLServer server(&acceptCallback, false);
1913
1914   // Set up SSL context.
1915   auto sslContext = std::make_shared<SSLContext>();
1916
1917   // connect
1918   auto socket =
1919       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1920   socket->enableTFO();
1921   socket->open();
1922
1923   // write()
1924   std::array<uint8_t, 128> buf;
1925   memset(buf.data(), 'a', buf.size());
1926   socket->write(buf.data(), buf.size());
1927
1928   // read()
1929   std::array<uint8_t, 128> readbuf;
1930   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1931   EXPECT_EQ(bytesRead, 128);
1932   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1933
1934   // close()
1935   socket->close();
1936 }
1937
1938 class ConnCallback : public AsyncSocket::ConnectCallback {
1939  public:
1940   void connectSuccess() noexcept override {
1941     state = State::SUCCESS;
1942   }
1943
1944   void connectErr(const AsyncSocketException& ex) noexcept override {
1945     state = State::ERROR;
1946     error = ex.what();
1947   }
1948
1949   enum class State { WAITING, SUCCESS, ERROR };
1950
1951   State state{State::WAITING};
1952   std::string error;
1953 };
1954
1955 template <class Cardinality>
1956 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1957     EventBase* evb,
1958     const SocketAddress& address,
1959     Cardinality cardinality) {
1960   // Set up SSL context.
1961   auto sslContext = std::make_shared<SSLContext>();
1962
1963   // connect
1964   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1965       new MockAsyncTFOSSLSocket(sslContext, evb));
1966   socket->enableTFO();
1967
1968   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1969       .Times(cardinality)
1970       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1971         sockaddr_storage addr;
1972         auto len = address.getAddress(&addr);
1973         return connect(fd, (const struct sockaddr*)&addr, len);
1974       }));
1975   return socket;
1976 }
1977
1978 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1979   // Start listening on a local port
1980   WriteCallbackBase writeCallback;
1981   ReadCallback readCallback(&writeCallback);
1982   HandshakeCallback handshakeCallback(&readCallback);
1983   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1984   TestSSLServer server(&acceptCallback, true);
1985
1986   EventBase evb;
1987
1988   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1989   ConnCallback ccb;
1990   socket->connect(&ccb, server.getAddress(), 30);
1991
1992   evb.loop();
1993   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1994
1995   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1996   evb.loop();
1997
1998   BlockingSocket sock(std::move(socket));
1999   // write()
2000   std::array<uint8_t, 128> buf;
2001   memset(buf.data(), 'a', buf.size());
2002   sock.write(buf.data(), buf.size());
2003
2004   // read()
2005   std::array<uint8_t, 128> readbuf;
2006   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
2007   EXPECT_EQ(bytesRead, 128);
2008   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
2009
2010   // close()
2011   sock.close();
2012 }
2013
2014 #if !defined(OPENSSL_IS_BORINGSSL)
2015 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
2016   // Start listening on a local port
2017   ConnectTimeoutCallback acceptCallback;
2018   TestSSLServer server(&acceptCallback, true);
2019
2020   // Set up SSL context.
2021   auto sslContext = std::make_shared<SSLContext>();
2022
2023   // connect
2024   auto socket =
2025       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2026   socket->enableTFO();
2027   EXPECT_THROW(
2028       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
2029 }
2030 #endif
2031
2032 #if !defined(OPENSSL_IS_BORINGSSL)
2033 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
2034   // Start listening on a local port
2035   ConnectTimeoutCallback acceptCallback;
2036   TestSSLServer server(&acceptCallback, true);
2037
2038   EventBase evb;
2039
2040   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2041   ConnCallback ccb;
2042   // Set a short timeout
2043   socket->connect(&ccb, server.getAddress(), 1);
2044
2045   evb.loop();
2046   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2047 }
2048 #endif
2049
2050 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
2051   // Start listening on a local port
2052   EmptyReadCallback readCallback;
2053   HandshakeCallback handshakeCallback(
2054       &readCallback, HandshakeCallback::EXPECT_ERROR);
2055   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
2056   TestSSLServer server(&acceptCallback, true);
2057
2058   EventBase evb;
2059
2060   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2061   ConnCallback ccb;
2062   socket->connect(&ccb, server.getAddress(), 100);
2063
2064   evb.loop();
2065   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2066   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
2067 }
2068
2069 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
2070   // Start listening on a local port
2071   EventBase evb;
2072
2073   // Hopefully nothing is listening on this address
2074   SocketAddress addr("127.0.0.1", 65535);
2075   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
2076   ConnCallback ccb;
2077   socket->connect(&ccb, addr, 100);
2078
2079   evb.loop();
2080   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2081   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
2082 }
2083
2084 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
2085   EventBase clientEventBase;
2086   EventBase serverEventBase;
2087   auto clientCtx = std::make_shared<SSLContext>();
2088   auto dfServerCtx = std::make_shared<SSLContext>();
2089   std::array<int, 2> fds;
2090   getfds(fds.data());
2091   getctx(clientCtx, dfServerCtx);
2092
2093   AsyncSSLSocket::UniquePtr clientSockPtr(
2094       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2095   AsyncSSLSocket::UniquePtr serverSockPtr(
2096       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
2097   auto clientSock = clientSockPtr.get();
2098   auto serverSock = serverSockPtr.get();
2099   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2100
2101   // Steal some data from the server.
2102   clientEventBase.loopOnce();
2103   std::array<uint8_t, 10> buf;
2104   recv(fds[1], buf.data(), buf.size(), 0);
2105
2106   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2107   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2108   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2109     serverEventBase.loopOnce();
2110     clientEventBase.loopOnce();
2111   }
2112
2113   EXPECT_TRUE(client.handshakeSuccess_);
2114   EXPECT_TRUE(server.handshakeSuccess_);
2115   EXPECT_EQ(
2116       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2117 }
2118
2119 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2120   EventBase clientEventBase;
2121   EventBase serverEventBase;
2122   auto clientCtx = std::make_shared<SSLContext>();
2123   auto dfServerCtx = std::make_shared<SSLContext>();
2124   std::array<int, 2> fds;
2125   getfds(fds.data());
2126   getctx(clientCtx, dfServerCtx);
2127
2128   AsyncSSLSocket::UniquePtr clientSockPtr(
2129       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2130   AsyncSocket::UniquePtr serverSockPtr(
2131       new AsyncSocket(&serverEventBase, fds[1]));
2132   auto clientSock = clientSockPtr.get();
2133   auto serverSock = serverSockPtr.get();
2134   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2135
2136   // Steal some data from the server.
2137   clientEventBase.loopOnce();
2138   std::array<uint8_t, 10> buf;
2139   recv(fds[1], buf.data(), buf.size(), 0);
2140   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2141   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2142       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2143   auto serverSSLSock = serverSSLSockPtr.get();
2144   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2145   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2146     serverEventBase.loopOnce();
2147     clientEventBase.loopOnce();
2148   }
2149
2150   EXPECT_TRUE(client.handshakeSuccess_);
2151   EXPECT_TRUE(server.handshakeSuccess_);
2152   EXPECT_EQ(
2153       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2154 }
2155
2156 /**
2157  * Test overriding the flags passed to "sendmsg()" system call,
2158  * and verifying that write requests fail properly.
2159  */
2160 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2161   // Start listening on a local port
2162   SendMsgFlagsCallback msgCallback;
2163   ExpectWriteErrorCallback writeCallback(&msgCallback);
2164   ReadCallback readCallback(&writeCallback);
2165   HandshakeCallback handshakeCallback(&readCallback);
2166   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2167   TestSSLServer server(&acceptCallback);
2168
2169   // Set up SSL context.
2170   auto sslContext = std::make_shared<SSLContext>();
2171   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2172
2173   // connect
2174   auto socket =
2175       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2176   socket->open();
2177
2178   // Setting flags to "-1" to trigger "Invalid argument" error
2179   // on attempt to use this flags in sendmsg() system call.
2180   msgCallback.resetFlags(-1);
2181
2182   // write()
2183   std::vector<uint8_t> buf(128, 'a');
2184   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2185
2186   // close()
2187   socket->close();
2188
2189   cerr << "SendMsgParamsCallback test completed" << endl;
2190 }
2191
2192 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
2193 /**
2194  * Test connecting to, writing to, reading from, and closing the
2195  * connection to the SSL server.
2196  */
2197 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2198   // This test requires Linux kernel v4.6 or later
2199   struct utsname s_uname;
2200   memset(&s_uname, 0, sizeof(s_uname));
2201   ASSERT_EQ(uname(&s_uname), 0);
2202   int major, minor;
2203   folly::StringPiece extra;
2204   if (folly::split<false>(
2205           '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2206     if (major < 4 || (major == 4 && minor < 6)) {
2207       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2208                 << "kernel ver. " << s_uname.release << " detected).";
2209       return;
2210     }
2211   }
2212
2213   // Start listening on a local port
2214   SendMsgDataCallback msgCallback;
2215   WriteCheckTimestampCallback writeCallback(&msgCallback);
2216   ReadCallback readCallback(&writeCallback);
2217   HandshakeCallback handshakeCallback(&readCallback);
2218   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2219   TestSSLServer server(&acceptCallback);
2220
2221   // Set up SSL context.
2222   auto sslContext = std::make_shared<SSLContext>();
2223   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2224
2225   // connect
2226   auto socket =
2227       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2228   socket->open();
2229
2230   // Adding MSG_EOR flag to the message flags - it'll trigger
2231   // timestamp generation for the last byte of the message.
2232   msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
2233
2234   // Init ancillary data buffer to trigger timestamp notification
2235   union {
2236     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2237     struct cmsghdr cmsg;
2238   } u;
2239   u.cmsg.cmsg_level = SOL_SOCKET;
2240   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2241   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2242   uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
2243       SOF_TIMESTAMPING_TX_ACK;
2244   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2245   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2246   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2247   msgCallback.resetData(std::move(ctrl));
2248
2249   // write()
2250   std::vector<uint8_t> buf(128, 'a');
2251   socket->write(buf.data(), buf.size());
2252
2253   // read()
2254   std::vector<uint8_t> readbuf(buf.size());
2255   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2256   EXPECT_EQ(bytesRead, buf.size());
2257   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2258
2259   writeCallback.checkForTimestampNotifications();
2260
2261   // close()
2262   socket->close();
2263
2264   cerr << "SendMsgDataCallback test completed" << endl;
2265 }
2266 #endif // FOLLY_HAVE_MSG_ERRQUEUE
2267
2268 #endif
2269
2270 } // namespace folly
2271
2272 #ifdef SIGPIPE
2273 ///////////////////////////////////////////////////////////////////////////
2274 // init_unit_test_suite
2275 ///////////////////////////////////////////////////////////////////////////
2276 namespace {
2277 struct Initializer {
2278   Initializer() {
2279     signal(SIGPIPE, SIG_IGN);
2280   }
2281 };
2282 Initializer initializer;
2283 } // namespace
2284 #endif