Clang-format AsyncSSLSocketTest.cpp.
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #ifdef MSG_ERRQUEUE
41 #include <sys/utsname.h>
42 #endif
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::list;
50
51 using namespace testing;
52
53 namespace folly {
54 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
55 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
56 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
57
58 constexpr size_t SSLClient::kMaxReadBufferSz;
59 constexpr size_t SSLClient::kMaxReadsPerEvent;
60
61 void getfds(int fds[2]) {
62   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
63     FAIL() << "failed to create socketpair: " << strerror(errno);
64   }
65   for (int idx = 0; idx < 2; ++idx) {
66     int flags = fcntl(fds[idx], F_GETFL, 0);
67     if (flags == -1) {
68       FAIL() << "failed to get flags for socket " << idx << ": "
69              << strerror(errno);
70     }
71     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
72       FAIL() << "failed to put socket " << idx
73              << " in non-blocking mode: " << strerror(errno);
74     }
75   }
76 }
77
78 void getctx(
79     std::shared_ptr<folly::SSLContext> clientCtx,
80     std::shared_ptr<folly::SSLContext> serverCtx) {
81   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
82
83   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
84   serverCtx->loadCertificate(kTestCert);
85   serverCtx->loadPrivateKey(kTestKey);
86 }
87
88 void sslsocketpair(
89     EventBase* eventBase,
90     AsyncSSLSocket::UniquePtr* clientSock,
91     AsyncSSLSocket::UniquePtr* serverSock) {
92   auto clientCtx = std::make_shared<folly::SSLContext>();
93   auto serverCtx = std::make_shared<folly::SSLContext>();
94   int fds[2];
95   getfds(fds);
96   getctx(clientCtx, serverCtx);
97   clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
98   serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
99
100   // (*clientSock)->setSendTimeout(100);
101   // (*serverSock)->setSendTimeout(100);
102 }
103
104 // client protocol filters
105 bool clientProtoFilterPickPony(
106     unsigned char** client,
107     unsigned int* client_len,
108     const unsigned char*,
109     unsigned int) {
110   // the protocol string in length prefixed byte string. the
111   // length byte is not included in the length
112   static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
113   *client = p;
114   *client_len = 7;
115   return true;
116 }
117
118 bool clientProtoFilterPickNone(
119     unsigned char**,
120     unsigned int*,
121     const unsigned char*,
122     unsigned int) {
123   return false;
124 }
125
126 std::string getFileAsBuf(const char* fileName) {
127   std::string buffer;
128   folly::readFile(fileName, buffer);
129   return buffer;
130 }
131
132 std::string getCommonName(X509* cert) {
133   X509_NAME* subject = X509_get_subject_name(cert);
134   std::string cn;
135   cn.resize(ub_common_name);
136   X509_NAME_get_text_by_NID(
137       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
138   return cn;
139 }
140
141 /**
142  * Test connecting to, writing to, reading from, and closing the
143  * connection to the SSL server.
144  */
145 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
146   // Start listening on a local port
147   WriteCallbackBase writeCallback;
148   ReadCallback readCallback(&writeCallback);
149   HandshakeCallback handshakeCallback(&readCallback);
150   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
151   TestSSLServer server(&acceptCallback);
152
153   // Set up SSL context.
154   std::shared_ptr<SSLContext> sslContext(new SSLContext());
155   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
156   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
157   // sslContext->authenticate(true, false);
158
159   // connect
160   auto socket =
161       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
162   socket->open(std::chrono::milliseconds(10000));
163
164   // write()
165   uint8_t buf[128];
166   memset(buf, 'a', sizeof(buf));
167   socket->write(buf, sizeof(buf));
168
169   // read()
170   uint8_t readbuf[128];
171   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
172   EXPECT_EQ(bytesRead, 128);
173   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
174
175   // close()
176   socket->close();
177
178   cerr << "ConnectWriteReadClose test completed" << endl;
179   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
180 }
181
182 /**
183  * Test reading after server close.
184  */
185 TEST(AsyncSSLSocketTest, ReadAfterClose) {
186   // Start listening on a local port
187   WriteCallbackBase writeCallback;
188   ReadEOFCallback readCallback(&writeCallback);
189   HandshakeCallback handshakeCallback(&readCallback);
190   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
191   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
192
193   // Set up SSL context.
194   auto sslContext = std::make_shared<SSLContext>();
195   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
196
197   auto socket =
198       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
199   socket->open();
200
201   // This should trigger an EOF on the client.
202   auto evb = handshakeCallback.getSocket()->getEventBase();
203   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
204   std::array<uint8_t, 128> readbuf;
205   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
206   EXPECT_EQ(0, bytesRead);
207 }
208
209 /**
210  * Test bad renegotiation
211  */
212 #if !defined(OPENSSL_IS_BORINGSSL)
213 TEST(AsyncSSLSocketTest, Renegotiate) {
214   EventBase eventBase;
215   auto clientCtx = std::make_shared<SSLContext>();
216   auto dfServerCtx = std::make_shared<SSLContext>();
217   std::array<int, 2> fds;
218   getfds(fds.data());
219   getctx(clientCtx, dfServerCtx);
220
221   AsyncSSLSocket::UniquePtr clientSock(
222       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
223   AsyncSSLSocket::UniquePtr serverSock(
224       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
225   SSLHandshakeClient client(std::move(clientSock), true, true);
226   RenegotiatingServer server(std::move(serverSock));
227
228   while (!client.handshakeSuccess_ && !client.handshakeError_) {
229     eventBase.loopOnce();
230   }
231
232   ASSERT_TRUE(client.handshakeSuccess_);
233
234   auto sslSock = std::move(client).moveSocket();
235   sslSock->detachEventBase();
236   // This is nasty, however we don't want to add support for
237   // renegotiation in AsyncSSLSocket.
238   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
239
240   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
241
242   std::thread t([&]() { eventBase.loopForever(); });
243
244   // Trigger the renegotiation.
245   std::array<uint8_t, 128> buf;
246   memset(buf.data(), 'a', buf.size());
247   try {
248     socket->write(buf.data(), buf.size());
249   } catch (AsyncSocketException& e) {
250     LOG(INFO) << "client got error " << e.what();
251   }
252   eventBase.terminateLoopSoon();
253   t.join();
254
255   eventBase.loop();
256   ASSERT_TRUE(server.renegotiationError_);
257 }
258 #endif
259
260 /**
261  * Negative test for handshakeError().
262  */
263 TEST(AsyncSSLSocketTest, HandshakeError) {
264   // Start listening on a local port
265   WriteCallbackBase writeCallback;
266   WriteErrorCallback readCallback(&writeCallback);
267   HandshakeCallback handshakeCallback(&readCallback);
268   HandshakeErrorCallback acceptCallback(&handshakeCallback);
269   TestSSLServer server(&acceptCallback);
270
271   // Set up SSL context.
272   std::shared_ptr<SSLContext> sslContext(new SSLContext());
273   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
274
275   // connect
276   auto socket =
277       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
278   // read()
279   bool ex = false;
280   try {
281     socket->open();
282
283     uint8_t readbuf[128];
284     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
285     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
286   } catch (AsyncSocketException&) {
287     ex = true;
288   }
289   EXPECT_TRUE(ex);
290
291   // close()
292   socket->close();
293   cerr << "HandshakeError test completed" << endl;
294 }
295
296 /**
297  * Negative test for readError().
298  */
299 TEST(AsyncSSLSocketTest, ReadError) {
300   // Start listening on a local port
301   WriteCallbackBase writeCallback;
302   ReadErrorCallback readCallback(&writeCallback);
303   HandshakeCallback handshakeCallback(&readCallback);
304   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
305   TestSSLServer server(&acceptCallback);
306
307   // Set up SSL context.
308   std::shared_ptr<SSLContext> sslContext(new SSLContext());
309   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
310
311   // connect
312   auto socket =
313       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
314   socket->open();
315
316   // write something to trigger ssl handshake
317   uint8_t buf[128];
318   memset(buf, 'a', sizeof(buf));
319   socket->write(buf, sizeof(buf));
320
321   socket->close();
322   cerr << "ReadError test completed" << endl;
323 }
324
325 /**
326  * Negative test for writeError().
327  */
328 TEST(AsyncSSLSocketTest, WriteError) {
329   // Start listening on a local port
330   WriteCallbackBase writeCallback;
331   WriteErrorCallback readCallback(&writeCallback);
332   HandshakeCallback handshakeCallback(&readCallback);
333   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
334   TestSSLServer server(&acceptCallback);
335
336   // Set up SSL context.
337   std::shared_ptr<SSLContext> sslContext(new SSLContext());
338   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
339
340   // connect
341   auto socket =
342       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
343   socket->open();
344
345   // write something to trigger ssl handshake
346   uint8_t buf[128];
347   memset(buf, 'a', sizeof(buf));
348   socket->write(buf, sizeof(buf));
349
350   socket->close();
351   cerr << "WriteError test completed" << endl;
352 }
353
354 /**
355  * Test a socket with TCP_NODELAY unset.
356  */
357 TEST(AsyncSSLSocketTest, SocketWithDelay) {
358   // Start listening on a local port
359   WriteCallbackBase writeCallback;
360   ReadCallback readCallback(&writeCallback);
361   HandshakeCallback handshakeCallback(&readCallback);
362   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
363   TestSSLServer server(&acceptCallback);
364
365   // Set up SSL context.
366   std::shared_ptr<SSLContext> sslContext(new SSLContext());
367   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
368
369   // connect
370   auto socket =
371       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
372   socket->open();
373
374   // write()
375   uint8_t buf[128];
376   memset(buf, 'a', sizeof(buf));
377   socket->write(buf, sizeof(buf));
378
379   // read()
380   uint8_t readbuf[128];
381   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
382   EXPECT_EQ(bytesRead, 128);
383   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
384
385   // close()
386   socket->close();
387
388   cerr << "SocketWithDelay test completed" << endl;
389 }
390
391 using NextProtocolTypePair =
392     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
393
394 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
395   // For matching protos
396  public:
397   void SetUp() override {
398     getctx(clientCtx, serverCtx);
399   }
400
401   void connect(bool unset = false) {
402     getfds(fds);
403
404     if (unset) {
405       // unsetting NPN for any of [client, server] is enough to make NPN not
406       // work
407       clientCtx->unsetNextProtocols();
408     }
409
410     AsyncSSLSocket::UniquePtr clientSock(
411         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
412     AsyncSSLSocket::UniquePtr serverSock(
413         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
414     client = std::make_unique<NpnClient>(std::move(clientSock));
415     server = std::make_unique<NpnServer>(std::move(serverSock));
416
417     eventBase.loop();
418   }
419
420   void expectProtocol(const std::string& proto) {
421     expectHandshakeSuccess();
422     EXPECT_NE(client->nextProtoLength, 0);
423     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
424     EXPECT_EQ(
425         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
426         0);
427     string selected((const char*)client->nextProto, client->nextProtoLength);
428     EXPECT_EQ(proto, selected);
429   }
430
431   void expectNoProtocol() {
432     expectHandshakeSuccess();
433     EXPECT_EQ(client->nextProtoLength, 0);
434     EXPECT_EQ(server->nextProtoLength, 0);
435     EXPECT_EQ(client->nextProto, nullptr);
436     EXPECT_EQ(server->nextProto, nullptr);
437   }
438
439   void expectProtocolType() {
440     expectHandshakeSuccess();
441     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
442         GetParam().second == SSLContext::NextProtocolType::ANY) {
443       EXPECT_EQ(client->protocolType, server->protocolType);
444     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
445                GetParam().second == SSLContext::NextProtocolType::ANY) {
446       // Well not much we can say
447     } else {
448       expectProtocolType(GetParam());
449     }
450   }
451
452   void expectProtocolType(NextProtocolTypePair expected) {
453     expectHandshakeSuccess();
454     EXPECT_EQ(client->protocolType, expected.first);
455     EXPECT_EQ(server->protocolType, expected.second);
456   }
457
458   void expectHandshakeSuccess() {
459     EXPECT_FALSE(client->except.hasValue())
460         << "client handshake error: " << client->except->what();
461     EXPECT_FALSE(server->except.hasValue())
462         << "server handshake error: " << server->except->what();
463   }
464
465   void expectHandshakeError() {
466     EXPECT_TRUE(client->except.hasValue())
467         << "Expected client handshake error!";
468     EXPECT_TRUE(server->except.hasValue())
469         << "Expected server handshake error!";
470   }
471
472   EventBase eventBase;
473   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
474   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
475   int fds[2];
476   std::unique_ptr<NpnClient> client;
477   std::unique_ptr<NpnServer> server;
478 };
479
480 class NextProtocolTLSExtTest : public NextProtocolTest {
481   // For extended TLS protos
482 };
483
484 class NextProtocolNPNOnlyTest : public NextProtocolTest {
485   // For mismatching protos
486 };
487
488 class NextProtocolMismatchTest : public NextProtocolTest {
489   // For mismatching protos
490 };
491
492 TEST_P(NextProtocolTest, NpnTestOverlap) {
493   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
494   serverCtx->setAdvertisedNextProtocols(
495       {"foo", "bar", "baz"}, GetParam().second);
496
497   connect();
498
499   expectProtocol("baz");
500   expectProtocolType();
501 }
502
503 TEST_P(NextProtocolTest, NpnTestUnset) {
504   // Identical to above test, except that we want unset NPN before
505   // looping.
506   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
507   serverCtx->setAdvertisedNextProtocols(
508       {"foo", "bar", "baz"}, GetParam().second);
509
510   connect(true /* unset */);
511
512   // if alpn negotiation fails, type will appear as npn
513   expectNoProtocol();
514   EXPECT_EQ(client->protocolType, server->protocolType);
515 }
516
517 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
518   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
519   serverCtx->setAdvertisedNextProtocols(
520       {"foo", "bar", "baz"}, GetParam().second);
521
522   connect();
523
524   expectNoProtocol();
525   expectProtocolType(
526       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
527 }
528
529 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
530 // will fail on 1.0.2 before that.
531 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
532   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
533   serverCtx->setAdvertisedNextProtocols(
534       {"foo", "bar", "baz"}, GetParam().second);
535   connect();
536
537   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
538       GetParam().second == SSLContext::NextProtocolType::ALPN) {
539     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
540     // mismatch should result in a fatal alert, but this is the current behavior
541     // on all OpenSSL versions/variants, and we want to know if it changes.
542     expectNoProtocol();
543   }
544 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
545   else if (
546       GetParam().first == SSLContext::NextProtocolType::ANY &&
547       GetParam().second == SSLContext::NextProtocolType::ANY) {
548 #if FOLLY_OPENSSL_IS_110
549     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
550     // correct behavior per RFC7301
551     expectHandshakeError();
552 #else
553     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
554     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
555     // NPN *and* ALPN
556     expectNoProtocol();
557 #endif
558   }
559 #endif
560   else {
561     expectProtocol("blub");
562     expectProtocolType(
563         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
564   }
565 }
566
567 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
568   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
569   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
570   serverCtx->setAdvertisedNextProtocols(
571       {"foo", "bar", "baz"}, GetParam().second);
572
573   connect();
574
575   expectProtocol("ponies");
576   expectProtocolType();
577 }
578
579 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
580   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
581   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
582   serverCtx->setAdvertisedNextProtocols(
583       {"foo", "bar", "baz"}, GetParam().second);
584
585   connect();
586
587   expectProtocol("blub");
588   expectProtocolType();
589 }
590
591 TEST_P(NextProtocolTest, RandomizedNpnTest) {
592   // Probability that this test will fail is 2^-64, which could be considered
593   // as negligible.
594   const int kTries = 64;
595
596   clientCtx->setAdvertisedNextProtocols(
597       {"foo", "bar", "baz"}, GetParam().first);
598   serverCtx->setRandomizedAdvertisedNextProtocols(
599       {{1, {"foo"}}, {1, {"bar"}}}, GetParam().second);
600
601   std::set<string> selectedProtocols;
602   for (int i = 0; i < kTries; ++i) {
603     connect();
604
605     EXPECT_NE(client->nextProtoLength, 0);
606     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
607     EXPECT_EQ(
608         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
609         0);
610     string selected((const char*)client->nextProto, client->nextProtoLength);
611     selectedProtocols.insert(selected);
612     expectProtocolType();
613   }
614   EXPECT_EQ(selectedProtocols.size(), 2);
615 }
616
617 INSTANTIATE_TEST_CASE_P(
618     AsyncSSLSocketTest,
619     NextProtocolTest,
620     ::testing::Values(
621         NextProtocolTypePair(
622             SSLContext::NextProtocolType::NPN,
623             SSLContext::NextProtocolType::NPN),
624         NextProtocolTypePair(
625             SSLContext::NextProtocolType::NPN,
626             SSLContext::NextProtocolType::ANY),
627         NextProtocolTypePair(
628             SSLContext::NextProtocolType::ANY,
629             SSLContext::NextProtocolType::ANY)));
630
631 #if FOLLY_OPENSSL_HAS_ALPN
632 INSTANTIATE_TEST_CASE_P(
633     AsyncSSLSocketTest,
634     NextProtocolTLSExtTest,
635     ::testing::Values(
636         NextProtocolTypePair(
637             SSLContext::NextProtocolType::ALPN,
638             SSLContext::NextProtocolType::ALPN),
639         NextProtocolTypePair(
640             SSLContext::NextProtocolType::ALPN,
641             SSLContext::NextProtocolType::ANY),
642         NextProtocolTypePair(
643             SSLContext::NextProtocolType::ANY,
644             SSLContext::NextProtocolType::ALPN)));
645 #endif
646
647 INSTANTIATE_TEST_CASE_P(
648     AsyncSSLSocketTest,
649     NextProtocolNPNOnlyTest,
650     ::testing::Values(NextProtocolTypePair(
651         SSLContext::NextProtocolType::NPN,
652         SSLContext::NextProtocolType::NPN)));
653
654 #if FOLLY_OPENSSL_HAS_ALPN
655 INSTANTIATE_TEST_CASE_P(
656     AsyncSSLSocketTest,
657     NextProtocolMismatchTest,
658     ::testing::Values(
659         NextProtocolTypePair(
660             SSLContext::NextProtocolType::NPN,
661             SSLContext::NextProtocolType::ALPN),
662         NextProtocolTypePair(
663             SSLContext::NextProtocolType::ALPN,
664             SSLContext::NextProtocolType::NPN)));
665 #endif
666
667 #ifndef OPENSSL_NO_TLSEXT
668 /**
669  * 1. Client sends TLSEXT_HOSTNAME in client hello.
670  * 2. Server found a match SSL_CTX and use this SSL_CTX to
671  *    continue the SSL handshake.
672  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
673  */
674 TEST(AsyncSSLSocketTest, SNITestMatch) {
675   EventBase eventBase;
676   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
677   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
678   // Use the same SSLContext to continue the handshake after
679   // tlsext_hostname match.
680   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
681   const std::string serverName("xyz.newdev.facebook.com");
682   int fds[2];
683   getfds(fds);
684   getctx(clientCtx, dfServerCtx);
685
686   AsyncSSLSocket::UniquePtr clientSock(
687       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
688   AsyncSSLSocket::UniquePtr serverSock(
689       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
690   SNIClient client(std::move(clientSock));
691   SNIServer server(
692       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
693
694   eventBase.loop();
695
696   EXPECT_TRUE(client.serverNameMatch);
697   EXPECT_TRUE(server.serverNameMatch);
698 }
699
700 /**
701  * 1. Client sends TLSEXT_HOSTNAME in client hello.
702  * 2. Server cannot find a matching SSL_CTX and continue to use
703  *    the current SSL_CTX to do the handshake.
704  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
705  */
706 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
707   EventBase eventBase;
708   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
709   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
710   // Use the same SSLContext to continue the handshake after
711   // tlsext_hostname match.
712   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
713   const std::string clientRequestingServerName("foo.com");
714   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
715
716   int fds[2];
717   getfds(fds);
718   getctx(clientCtx, dfServerCtx);
719
720   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
721       clientCtx, &eventBase, fds[0], clientRequestingServerName));
722   AsyncSSLSocket::UniquePtr serverSock(
723       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
724   SNIClient client(std::move(clientSock));
725   SNIServer server(
726       std::move(serverSock),
727       dfServerCtx,
728       hskServerCtx,
729       serverExpectedServerName);
730
731   eventBase.loop();
732
733   EXPECT_TRUE(!client.serverNameMatch);
734   EXPECT_TRUE(!server.serverNameMatch);
735 }
736 /**
737  * 1. Client sends TLSEXT_HOSTNAME in client hello.
738  * 2. We then change the serverName.
739  * 3. We expect that we get 'false' as the result for serNameMatch.
740  */
741
742 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
743   EventBase eventBase;
744   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
745   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
746   // Use the same SSLContext to continue the handshake after
747   // tlsext_hostname match.
748   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
749   const std::string serverName("xyz.newdev.facebook.com");
750   int fds[2];
751   getfds(fds);
752   getctx(clientCtx, dfServerCtx);
753
754   AsyncSSLSocket::UniquePtr clientSock(
755       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
756   // Change the server name
757   std::string newName("new.com");
758   clientSock->setServerName(newName);
759   AsyncSSLSocket::UniquePtr serverSock(
760       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
761   SNIClient client(std::move(clientSock));
762   SNIServer server(
763       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
764
765   eventBase.loop();
766
767   EXPECT_TRUE(!client.serverNameMatch);
768 }
769
770 /**
771  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
772  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
773  */
774 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
775   EventBase eventBase;
776   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
777   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
778   // Use the same SSLContext to continue the handshake after
779   // tlsext_hostname match.
780   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
781   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
782
783   int fds[2];
784   getfds(fds);
785   getctx(clientCtx, dfServerCtx);
786
787   AsyncSSLSocket::UniquePtr clientSock(
788       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
789   AsyncSSLSocket::UniquePtr serverSock(
790       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
791   SNIClient client(std::move(clientSock));
792   SNIServer server(
793       std::move(serverSock),
794       dfServerCtx,
795       hskServerCtx,
796       serverExpectedServerName);
797
798   eventBase.loop();
799
800   EXPECT_TRUE(!client.serverNameMatch);
801   EXPECT_TRUE(!server.serverNameMatch);
802 }
803
804 #endif
805 /**
806  * Test SSL client socket
807  */
808 TEST(AsyncSSLSocketTest, SSLClientTest) {
809   // Start listening on a local port
810   WriteCallbackBase writeCallback;
811   ReadCallback readCallback(&writeCallback);
812   HandshakeCallback handshakeCallback(&readCallback);
813   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
814   TestSSLServer server(&acceptCallback);
815
816   // Set up SSL client
817   EventBase eventBase;
818   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
819
820   client->connect();
821   EventBaseAborter eba(&eventBase, 3000);
822   eventBase.loop();
823
824   EXPECT_EQ(client->getMiss(), 1);
825   EXPECT_EQ(client->getHit(), 0);
826
827   cerr << "SSLClientTest test completed" << endl;
828 }
829
830 /**
831  * Test SSL client socket session re-use
832  */
833 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
834   // Start listening on a local port
835   WriteCallbackBase writeCallback;
836   ReadCallback readCallback(&writeCallback);
837   HandshakeCallback handshakeCallback(&readCallback);
838   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
839   TestSSLServer server(&acceptCallback);
840
841   // Set up SSL client
842   EventBase eventBase;
843   auto client =
844       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
845
846   client->connect();
847   EventBaseAborter eba(&eventBase, 3000);
848   eventBase.loop();
849
850   EXPECT_EQ(client->getMiss(), 1);
851   EXPECT_EQ(client->getHit(), 9);
852
853   cerr << "SSLClientTestReuse test completed" << endl;
854 }
855
856 /**
857  * Test SSL client socket timeout
858  */
859 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
860   // Start listening on a local port
861   EmptyReadCallback readCallback;
862   HandshakeCallback handshakeCallback(
863       &readCallback, HandshakeCallback::EXPECT_ERROR);
864   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
865   TestSSLServer server(&acceptCallback);
866
867   // Set up SSL client
868   EventBase eventBase;
869   auto client =
870       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
871   client->connect(true /* write before connect completes */);
872   EventBaseAborter eba(&eventBase, 3000);
873   eventBase.loop();
874
875   usleep(100000);
876   // This is checking that the connectError callback precedes any queued
877   // writeError callbacks.  This matches AsyncSocket's behavior
878   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
879   EXPECT_EQ(client->getErrors(), 1);
880   EXPECT_EQ(client->getMiss(), 0);
881   EXPECT_EQ(client->getHit(), 0);
882
883   cerr << "SSLClientTimeoutTest test completed" << endl;
884 }
885
886 // The next 3 tests need an FB-only extension, and will fail without it
887 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
888 /**
889  * Test SSL server async cache
890  */
891 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
892   // Start listening on a local port
893   WriteCallbackBase writeCallback;
894   ReadCallback readCallback(&writeCallback);
895   HandshakeCallback handshakeCallback(&readCallback);
896   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
897   TestSSLAsyncCacheServer server(&acceptCallback);
898
899   // Set up SSL client
900   EventBase eventBase;
901   auto client =
902       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
903
904   client->connect();
905   EventBaseAborter eba(&eventBase, 3000);
906   eventBase.loop();
907
908   EXPECT_EQ(server.getAsyncCallbacks(), 18);
909   EXPECT_EQ(server.getAsyncLookups(), 9);
910   EXPECT_EQ(client->getMiss(), 10);
911   EXPECT_EQ(client->getHit(), 0);
912
913   cerr << "SSLServerAsyncCacheTest test completed" << endl;
914 }
915
916 /**
917  * Test SSL server accept timeout with cache path
918  */
919 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
920   // Start listening on a local port
921   WriteCallbackBase writeCallback;
922   ReadCallback readCallback(&writeCallback);
923   HandshakeCallback handshakeCallback(&readCallback);
924   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
925   TestSSLAsyncCacheServer server(&acceptCallback);
926
927   // Set up SSL client
928   EventBase eventBase;
929   // only do a TCP connect
930   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
931   sock->connect(nullptr, server.getAddress());
932
933   EmptyReadCallback clientReadCallback;
934   clientReadCallback.tcpSocket_ = sock;
935   sock->setReadCB(&clientReadCallback);
936
937   EventBaseAborter eba(&eventBase, 3000);
938   eventBase.loop();
939
940   EXPECT_EQ(readCallback.state, STATE_WAITING);
941
942   cerr << "SSLServerTimeoutTest test completed" << endl;
943 }
944
945 /**
946  * Test SSL server accept timeout with cache path
947  */
948 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
949   // Start listening on a local port
950   WriteCallbackBase writeCallback;
951   ReadCallback readCallback(&writeCallback);
952   HandshakeCallback handshakeCallback(&readCallback);
953   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
954   TestSSLAsyncCacheServer server(&acceptCallback);
955
956   // Set up SSL client
957   EventBase eventBase;
958   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
959
960   client->connect();
961   EventBaseAborter eba(&eventBase, 3000);
962   eventBase.loop();
963
964   EXPECT_EQ(server.getAsyncCallbacks(), 1);
965   EXPECT_EQ(server.getAsyncLookups(), 1);
966   EXPECT_EQ(client->getErrors(), 1);
967   EXPECT_EQ(client->getMiss(), 1);
968   EXPECT_EQ(client->getHit(), 0);
969
970   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
971 }
972
973 /**
974  * Test SSL server accept timeout with cache path
975  */
976 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
977   // Start listening on a local port
978   WriteCallbackBase writeCallback;
979   ReadCallback readCallback(&writeCallback);
980   HandshakeCallback handshakeCallback(
981       &readCallback, HandshakeCallback::EXPECT_ERROR);
982   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
983   TestSSLAsyncCacheServer server(&acceptCallback, 500);
984
985   // Set up SSL client
986   EventBase eventBase;
987   auto client =
988       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
989
990   client->connect();
991   EventBaseAborter eba(&eventBase, 3000);
992   eventBase.loop();
993
994   server.getEventBase().runInEventBaseThread(
995       [&handshakeCallback] { handshakeCallback.closeSocket(); });
996   // give time for the cache lookup to come back and find it closed
997   handshakeCallback.waitForHandshake();
998
999   EXPECT_EQ(server.getAsyncCallbacks(), 1);
1000   EXPECT_EQ(server.getAsyncLookups(), 1);
1001   EXPECT_EQ(client->getErrors(), 1);
1002   EXPECT_EQ(client->getMiss(), 1);
1003   EXPECT_EQ(client->getHit(), 0);
1004
1005   cerr << "SSLServerCacheCloseTest test completed" << endl;
1006 }
1007 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1008
1009 /**
1010  * Verify Client Ciphers obtained using SSL MSG Callback.
1011  */
1012 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1013   EventBase eventBase;
1014   auto clientCtx = std::make_shared<SSLContext>();
1015   auto serverCtx = std::make_shared<SSLContext>();
1016   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1017   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1018   serverCtx->loadPrivateKey(kTestKey);
1019   serverCtx->loadCertificate(kTestCert);
1020   serverCtx->loadTrustedCertificates(kTestCA);
1021   serverCtx->loadClientCAList(kTestCA);
1022
1023   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1024   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1025   clientCtx->loadPrivateKey(kTestKey);
1026   clientCtx->loadCertificate(kTestCert);
1027   clientCtx->loadTrustedCertificates(kTestCA);
1028
1029   int fds[2];
1030   getfds(fds);
1031
1032   AsyncSSLSocket::UniquePtr clientSock(
1033       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1034   AsyncSSLSocket::UniquePtr serverSock(
1035       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1036
1037   SSLHandshakeClient client(std::move(clientSock), true, true);
1038   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1039
1040   eventBase.loop();
1041
1042 #if defined(OPENSSL_IS_BORINGSSL)
1043   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1044 #else
1045   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1046 #endif
1047   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1048   EXPECT_TRUE(client.handshakeVerify_);
1049   EXPECT_TRUE(client.handshakeSuccess_);
1050   EXPECT_TRUE(!client.handshakeError_);
1051   EXPECT_TRUE(server.handshakeVerify_);
1052   EXPECT_TRUE(server.handshakeSuccess_);
1053   EXPECT_TRUE(!server.handshakeError_);
1054 }
1055
1056 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1057   EventBase eventBase;
1058   auto ctx = std::make_shared<SSLContext>();
1059
1060   int fds[2];
1061   getfds(fds);
1062
1063   int bufLen = 42;
1064   uint8_t majorVersion = 18;
1065   uint8_t minorVersion = 25;
1066
1067   // Create callback buf
1068   auto buf = IOBuf::create(bufLen);
1069   buf->append(bufLen);
1070   folly::io::RWPrivateCursor cursor(buf.get());
1071   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1072   cursor.write<uint16_t>(0);
1073   cursor.write<uint8_t>(38);
1074   cursor.write<uint8_t>(majorVersion);
1075   cursor.write<uint8_t>(minorVersion);
1076   cursor.skip(32);
1077   cursor.write<uint32_t>(0);
1078
1079   SSL* ssl = ctx->createSSL();
1080   SCOPE_EXIT {
1081     SSL_free(ssl);
1082   };
1083   AsyncSSLSocket::UniquePtr sock(
1084       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1085   sock->enableClientHelloParsing();
1086
1087   // Test client hello parsing in one packet
1088   AsyncSSLSocket::clientHelloParsingCallback(
1089       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1090   buf.reset();
1091
1092   auto parsedClientHello = sock->getClientHelloInfo();
1093   EXPECT_TRUE(parsedClientHello != nullptr);
1094   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1095   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1096 }
1097
1098 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1099   EventBase eventBase;
1100   auto ctx = std::make_shared<SSLContext>();
1101
1102   int fds[2];
1103   getfds(fds);
1104
1105   int bufLen = 42;
1106   uint8_t majorVersion = 18;
1107   uint8_t minorVersion = 25;
1108
1109   // Create callback buf
1110   auto buf = IOBuf::create(bufLen);
1111   buf->append(bufLen);
1112   folly::io::RWPrivateCursor cursor(buf.get());
1113   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1114   cursor.write<uint16_t>(0);
1115   cursor.write<uint8_t>(38);
1116   cursor.write<uint8_t>(majorVersion);
1117   cursor.write<uint8_t>(minorVersion);
1118   cursor.skip(32);
1119   cursor.write<uint32_t>(0);
1120
1121   SSL* ssl = ctx->createSSL();
1122   SCOPE_EXIT {
1123     SSL_free(ssl);
1124   };
1125   AsyncSSLSocket::UniquePtr sock(
1126       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1127   sock->enableClientHelloParsing();
1128
1129   // Test parsing with two packets with first packet size < 3
1130   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1131   AsyncSSLSocket::clientHelloParsingCallback(
1132       0,
1133       0,
1134       SSL3_RT_HANDSHAKE,
1135       bufCopy->data(),
1136       bufCopy->length(),
1137       ssl,
1138       sock.get());
1139   bufCopy.reset();
1140   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1141   AsyncSSLSocket::clientHelloParsingCallback(
1142       0,
1143       0,
1144       SSL3_RT_HANDSHAKE,
1145       bufCopy->data(),
1146       bufCopy->length(),
1147       ssl,
1148       sock.get());
1149   bufCopy.reset();
1150
1151   auto parsedClientHello = sock->getClientHelloInfo();
1152   EXPECT_TRUE(parsedClientHello != nullptr);
1153   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1154   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1155 }
1156
1157 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1158   EventBase eventBase;
1159   auto ctx = std::make_shared<SSLContext>();
1160
1161   int fds[2];
1162   getfds(fds);
1163
1164   int bufLen = 42;
1165   uint8_t majorVersion = 18;
1166   uint8_t minorVersion = 25;
1167
1168   // Create callback buf
1169   auto buf = IOBuf::create(bufLen);
1170   buf->append(bufLen);
1171   folly::io::RWPrivateCursor cursor(buf.get());
1172   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1173   cursor.write<uint16_t>(0);
1174   cursor.write<uint8_t>(38);
1175   cursor.write<uint8_t>(majorVersion);
1176   cursor.write<uint8_t>(minorVersion);
1177   cursor.skip(32);
1178   cursor.write<uint32_t>(0);
1179
1180   SSL* ssl = ctx->createSSL();
1181   SCOPE_EXIT {
1182     SSL_free(ssl);
1183   };
1184   AsyncSSLSocket::UniquePtr sock(
1185       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1186   sock->enableClientHelloParsing();
1187
1188   // Test parsing with multiple small packets
1189   for (uint64_t i = 0; i < buf->length(); i += 3) {
1190     auto bufCopy = folly::IOBuf::copyBuffer(
1191         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1192     AsyncSSLSocket::clientHelloParsingCallback(
1193         0,
1194         0,
1195         SSL3_RT_HANDSHAKE,
1196         bufCopy->data(),
1197         bufCopy->length(),
1198         ssl,
1199         sock.get());
1200     bufCopy.reset();
1201   }
1202
1203   auto parsedClientHello = sock->getClientHelloInfo();
1204   EXPECT_TRUE(parsedClientHello != nullptr);
1205   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1206   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1207 }
1208
1209 /**
1210  * Verify sucessful behavior of SSL certificate validation.
1211  */
1212 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1213   EventBase eventBase;
1214   auto clientCtx = std::make_shared<SSLContext>();
1215   auto dfServerCtx = std::make_shared<SSLContext>();
1216
1217   int fds[2];
1218   getfds(fds);
1219   getctx(clientCtx, dfServerCtx);
1220
1221   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1222   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1223
1224   AsyncSSLSocket::UniquePtr clientSock(
1225       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1226   AsyncSSLSocket::UniquePtr serverSock(
1227       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1228
1229   SSLHandshakeClient client(std::move(clientSock), true, true);
1230   clientCtx->loadTrustedCertificates(kTestCA);
1231
1232   SSLHandshakeServer server(std::move(serverSock), true, true);
1233
1234   eventBase.loop();
1235
1236   EXPECT_TRUE(client.handshakeVerify_);
1237   EXPECT_TRUE(client.handshakeSuccess_);
1238   EXPECT_TRUE(!client.handshakeError_);
1239   EXPECT_LE(0, client.handshakeTime.count());
1240   EXPECT_TRUE(!server.handshakeVerify_);
1241   EXPECT_TRUE(server.handshakeSuccess_);
1242   EXPECT_TRUE(!server.handshakeError_);
1243   EXPECT_LE(0, server.handshakeTime.count());
1244 }
1245
1246 /**
1247  * Verify that the client's verification callback is able to fail SSL
1248  * connection establishment.
1249  */
1250 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1251   EventBase eventBase;
1252   auto clientCtx = std::make_shared<SSLContext>();
1253   auto dfServerCtx = std::make_shared<SSLContext>();
1254
1255   int fds[2];
1256   getfds(fds);
1257   getctx(clientCtx, dfServerCtx);
1258
1259   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1260   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1261
1262   AsyncSSLSocket::UniquePtr clientSock(
1263       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1264   AsyncSSLSocket::UniquePtr serverSock(
1265       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1266
1267   SSLHandshakeClient client(std::move(clientSock), true, false);
1268   clientCtx->loadTrustedCertificates(kTestCA);
1269
1270   SSLHandshakeServer server(std::move(serverSock), true, true);
1271
1272   eventBase.loop();
1273
1274   EXPECT_TRUE(client.handshakeVerify_);
1275   EXPECT_TRUE(!client.handshakeSuccess_);
1276   EXPECT_TRUE(client.handshakeError_);
1277   EXPECT_LE(0, client.handshakeTime.count());
1278   EXPECT_TRUE(!server.handshakeVerify_);
1279   EXPECT_TRUE(!server.handshakeSuccess_);
1280   EXPECT_TRUE(server.handshakeError_);
1281   EXPECT_LE(0, server.handshakeTime.count());
1282 }
1283
1284 /**
1285  * Verify that the options in SSLContext can be overridden in
1286  * sslConnect/Accept.i.e specifying that no validation should be performed
1287  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1288  * the validation callback.
1289  */
1290 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1291   EventBase eventBase;
1292   auto clientCtx = std::make_shared<SSLContext>();
1293   auto dfServerCtx = std::make_shared<SSLContext>();
1294
1295   int fds[2];
1296   getfds(fds);
1297   getctx(clientCtx, dfServerCtx);
1298
1299   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1300   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1301
1302   AsyncSSLSocket::UniquePtr clientSock(
1303       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1304   AsyncSSLSocket::UniquePtr serverSock(
1305       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1306
1307   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1308   clientCtx->loadTrustedCertificates(kTestCA);
1309
1310   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1311
1312   eventBase.loop();
1313
1314   EXPECT_TRUE(!client.handshakeVerify_);
1315   EXPECT_TRUE(client.handshakeSuccess_);
1316   EXPECT_TRUE(!client.handshakeError_);
1317   EXPECT_LE(0, client.handshakeTime.count());
1318   EXPECT_TRUE(!server.handshakeVerify_);
1319   EXPECT_TRUE(server.handshakeSuccess_);
1320   EXPECT_TRUE(!server.handshakeError_);
1321   EXPECT_LE(0, server.handshakeTime.count());
1322 }
1323
1324 /**
1325  * Verify that the options in SSLContext can be overridden in
1326  * sslConnect/Accept. Enable verification even if context says otherwise.
1327  * Test requireClientCert with client cert
1328  */
1329 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1330   EventBase eventBase;
1331   auto clientCtx = std::make_shared<SSLContext>();
1332   auto serverCtx = std::make_shared<SSLContext>();
1333   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1334   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1335   serverCtx->loadPrivateKey(kTestKey);
1336   serverCtx->loadCertificate(kTestCert);
1337   serverCtx->loadTrustedCertificates(kTestCA);
1338   serverCtx->loadClientCAList(kTestCA);
1339
1340   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1341   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1342   clientCtx->loadPrivateKey(kTestKey);
1343   clientCtx->loadCertificate(kTestCert);
1344   clientCtx->loadTrustedCertificates(kTestCA);
1345
1346   int fds[2];
1347   getfds(fds);
1348
1349   AsyncSSLSocket::UniquePtr clientSock(
1350       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1351   AsyncSSLSocket::UniquePtr serverSock(
1352       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1353
1354   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1355   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1356
1357   eventBase.loop();
1358
1359   EXPECT_TRUE(client.handshakeVerify_);
1360   EXPECT_TRUE(client.handshakeSuccess_);
1361   EXPECT_FALSE(client.handshakeError_);
1362   EXPECT_LE(0, client.handshakeTime.count());
1363   EXPECT_TRUE(server.handshakeVerify_);
1364   EXPECT_TRUE(server.handshakeSuccess_);
1365   EXPECT_FALSE(server.handshakeError_);
1366   EXPECT_LE(0, server.handshakeTime.count());
1367 }
1368
1369 /**
1370  * Verify that the client's verification callback is able to override
1371  * the preverification failure and allow a successful connection.
1372  */
1373 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1374   EventBase eventBase;
1375   auto clientCtx = std::make_shared<SSLContext>();
1376   auto dfServerCtx = std::make_shared<SSLContext>();
1377
1378   int fds[2];
1379   getfds(fds);
1380   getctx(clientCtx, dfServerCtx);
1381
1382   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1383   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1384
1385   AsyncSSLSocket::UniquePtr clientSock(
1386       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1387   AsyncSSLSocket::UniquePtr serverSock(
1388       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1389
1390   SSLHandshakeClient client(std::move(clientSock), false, true);
1391   SSLHandshakeServer server(std::move(serverSock), true, true);
1392
1393   eventBase.loop();
1394
1395   EXPECT_TRUE(client.handshakeVerify_);
1396   EXPECT_TRUE(client.handshakeSuccess_);
1397   EXPECT_TRUE(!client.handshakeError_);
1398   EXPECT_LE(0, client.handshakeTime.count());
1399   EXPECT_TRUE(!server.handshakeVerify_);
1400   EXPECT_TRUE(server.handshakeSuccess_);
1401   EXPECT_TRUE(!server.handshakeError_);
1402   EXPECT_LE(0, server.handshakeTime.count());
1403 }
1404
1405 /**
1406  * Verify that specifying that no validation should be performed allows an
1407  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1408  * callback.
1409  */
1410 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1411   EventBase eventBase;
1412   auto clientCtx = std::make_shared<SSLContext>();
1413   auto dfServerCtx = std::make_shared<SSLContext>();
1414
1415   int fds[2];
1416   getfds(fds);
1417   getctx(clientCtx, dfServerCtx);
1418
1419   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1420   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1421
1422   AsyncSSLSocket::UniquePtr clientSock(
1423       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1424   AsyncSSLSocket::UniquePtr serverSock(
1425       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1426
1427   SSLHandshakeClient client(std::move(clientSock), false, false);
1428   SSLHandshakeServer server(std::move(serverSock), false, false);
1429
1430   eventBase.loop();
1431
1432   EXPECT_TRUE(!client.handshakeVerify_);
1433   EXPECT_TRUE(client.handshakeSuccess_);
1434   EXPECT_TRUE(!client.handshakeError_);
1435   EXPECT_LE(0, client.handshakeTime.count());
1436   EXPECT_TRUE(!server.handshakeVerify_);
1437   EXPECT_TRUE(server.handshakeSuccess_);
1438   EXPECT_TRUE(!server.handshakeError_);
1439   EXPECT_LE(0, server.handshakeTime.count());
1440 }
1441
1442 /**
1443  * Test requireClientCert with client cert
1444  */
1445 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1446   EventBase eventBase;
1447   auto clientCtx = std::make_shared<SSLContext>();
1448   auto serverCtx = std::make_shared<SSLContext>();
1449   serverCtx->setVerificationOption(
1450       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1451   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1452   serverCtx->loadPrivateKey(kTestKey);
1453   serverCtx->loadCertificate(kTestCert);
1454   serverCtx->loadTrustedCertificates(kTestCA);
1455   serverCtx->loadClientCAList(kTestCA);
1456
1457   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1458   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1459   clientCtx->loadPrivateKey(kTestKey);
1460   clientCtx->loadCertificate(kTestCert);
1461   clientCtx->loadTrustedCertificates(kTestCA);
1462
1463   int fds[2];
1464   getfds(fds);
1465
1466   AsyncSSLSocket::UniquePtr clientSock(
1467       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1468   AsyncSSLSocket::UniquePtr serverSock(
1469       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1470
1471   SSLHandshakeClient client(std::move(clientSock), true, true);
1472   SSLHandshakeServer server(std::move(serverSock), true, true);
1473
1474   eventBase.loop();
1475
1476   EXPECT_TRUE(client.handshakeVerify_);
1477   EXPECT_TRUE(client.handshakeSuccess_);
1478   EXPECT_FALSE(client.handshakeError_);
1479   EXPECT_LE(0, client.handshakeTime.count());
1480   EXPECT_TRUE(server.handshakeVerify_);
1481   EXPECT_TRUE(server.handshakeSuccess_);
1482   EXPECT_FALSE(server.handshakeError_);
1483   EXPECT_LE(0, server.handshakeTime.count());
1484 }
1485
1486 /**
1487  * Test requireClientCert with no client cert
1488  */
1489 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1490   EventBase eventBase;
1491   auto clientCtx = std::make_shared<SSLContext>();
1492   auto serverCtx = std::make_shared<SSLContext>();
1493   serverCtx->setVerificationOption(
1494       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1495   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1496   serverCtx->loadPrivateKey(kTestKey);
1497   serverCtx->loadCertificate(kTestCert);
1498   serverCtx->loadTrustedCertificates(kTestCA);
1499   serverCtx->loadClientCAList(kTestCA);
1500   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1501   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1502
1503   int fds[2];
1504   getfds(fds);
1505
1506   AsyncSSLSocket::UniquePtr clientSock(
1507       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1508   AsyncSSLSocket::UniquePtr serverSock(
1509       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1510
1511   SSLHandshakeClient client(std::move(clientSock), false, false);
1512   SSLHandshakeServer server(std::move(serverSock), false, false);
1513
1514   eventBase.loop();
1515
1516   EXPECT_FALSE(server.handshakeVerify_);
1517   EXPECT_FALSE(server.handshakeSuccess_);
1518   EXPECT_TRUE(server.handshakeError_);
1519   EXPECT_LE(0, client.handshakeTime.count());
1520   EXPECT_LE(0, server.handshakeTime.count());
1521 }
1522
1523 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1524   auto cert = getFileAsBuf(kTestCert);
1525   auto key = getFileAsBuf(kTestKey);
1526
1527   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1528   BIO_write(certBio.get(), cert.data(), cert.size());
1529   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1530   BIO_write(keyBio.get(), key.data(), key.size());
1531
1532   // Create SSL structs from buffers to get properties
1533   ssl::X509UniquePtr certStruct(
1534       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1535   ssl::EvpPkeyUniquePtr keyStruct(
1536       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1537   certBio = nullptr;
1538   keyBio = nullptr;
1539
1540   auto origCommonName = getCommonName(certStruct.get());
1541   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1542   certStruct = nullptr;
1543   keyStruct = nullptr;
1544
1545   auto ctx = std::make_shared<SSLContext>();
1546   ctx->loadPrivateKeyFromBufferPEM(key);
1547   ctx->loadCertificateFromBufferPEM(cert);
1548   ctx->loadTrustedCertificates(kTestCA);
1549
1550   ssl::SSLUniquePtr ssl(ctx->createSSL());
1551
1552   auto newCert = SSL_get_certificate(ssl.get());
1553   auto newKey = SSL_get_privatekey(ssl.get());
1554
1555   // Get properties from SSL struct
1556   auto newCommonName = getCommonName(newCert);
1557   auto newKeySize = EVP_PKEY_bits(newKey);
1558
1559   // Check that the key and cert have the expected properties
1560   EXPECT_EQ(origCommonName, newCommonName);
1561   EXPECT_EQ(origKeySize, newKeySize);
1562 }
1563
1564 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1565   EventBase eb;
1566
1567   // Set up SSL context.
1568   auto sslContext = std::make_shared<SSLContext>();
1569   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1570
1571   // create SSL socket
1572   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1573
1574   EXPECT_EQ(1500, socket->getMinWriteSize());
1575
1576   socket->setMinWriteSize(0);
1577   EXPECT_EQ(0, socket->getMinWriteSize());
1578   socket->setMinWriteSize(50000);
1579   EXPECT_EQ(50000, socket->getMinWriteSize());
1580 }
1581
1582 class ReadCallbackTerminator : public ReadCallback {
1583  public:
1584   ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
1585       : ReadCallback(wcb), base_(base) {}
1586
1587   // Do not write data back, terminate the loop.
1588   void readDataAvailable(size_t len) noexcept override {
1589     std::cerr << "readDataAvailable, len " << len << std::endl;
1590
1591     currentBuffer.length = len;
1592
1593     buffers.push_back(currentBuffer);
1594     currentBuffer.reset();
1595     state = STATE_SUCCEEDED;
1596
1597     socket_->setReadCB(nullptr);
1598     base_->terminateLoopSoon();
1599   }
1600
1601  private:
1602   EventBase* base_;
1603 };
1604
1605 /**
1606  * Test a full unencrypted codepath
1607  */
1608 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1609   EventBase base;
1610
1611   auto clientCtx = std::make_shared<folly::SSLContext>();
1612   auto serverCtx = std::make_shared<folly::SSLContext>();
1613   int fds[2];
1614   getfds(fds);
1615   getctx(clientCtx, serverCtx);
1616   auto client =
1617       AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
1618   auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
1619
1620   ReadCallbackTerminator readCallback(&base, nullptr);
1621   server->setReadCB(&readCallback);
1622   readCallback.setSocket(server);
1623
1624   uint8_t buf[128];
1625   memset(buf, 'a', sizeof(buf));
1626   client->write(nullptr, buf, sizeof(buf));
1627
1628   // Check that bytes are unencrypted
1629   char c;
1630   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1631   EXPECT_EQ('a', c);
1632
1633   EventBaseAborter eba(&base, 3000);
1634   base.loop();
1635
1636   EXPECT_EQ(1, readCallback.buffers.size());
1637   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1638
1639   server->setReadCB(&readCallback);
1640
1641   // Unencrypted
1642   server->sslAccept(nullptr);
1643   client->sslConn(nullptr);
1644
1645   // Do NOT wait for handshake, writing should be queued and happen after
1646
1647   client->write(nullptr, buf, sizeof(buf));
1648
1649   // Check that bytes are *not* unencrypted
1650   char c2;
1651   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1652   EXPECT_NE('a', c2);
1653
1654   base.loop();
1655
1656   EXPECT_EQ(2, readCallback.buffers.size());
1657   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1658 }
1659
1660 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1661   auto clientCtx = std::make_shared<folly::SSLContext>();
1662   auto serverCtx = std::make_shared<folly::SSLContext>();
1663   getctx(clientCtx, serverCtx);
1664
1665   WriteCallbackBase writeCallback;
1666   ReadCallback readCallback(&writeCallback);
1667   HandshakeCallback handshakeCallback(&readCallback);
1668   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1669   TestSSLServer server(&acceptCallback);
1670
1671   EventBase evb;
1672   std::shared_ptr<AsyncSSLSocket> socket =
1673       AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1674   socket->connect(nullptr, server.getAddress(), 0);
1675
1676   evb.loop();
1677
1678   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1679   socket->sslConn(nullptr);
1680   evb.loop();
1681   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1682
1683   // write()
1684   std::array<uint8_t, 128> buf;
1685   memset(buf.data(), 'a', buf.size());
1686   socket->write(nullptr, buf.data(), buf.size());
1687
1688   socket->close();
1689 }
1690
1691 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1692   // Start listening on a local port
1693   WriteCallbackBase writeCallback;
1694   WriteErrorCallback readCallback(&writeCallback);
1695   HandshakeCallback handshakeCallback(
1696       &readCallback, HandshakeCallback::EXPECT_ERROR);
1697   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1698   TestSSLServer server(&acceptCallback);
1699
1700   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1701   socket->open();
1702   uint8_t buf[3] = {0x16, 0x03, 0x01};
1703   socket->write(buf, sizeof(buf));
1704   socket->closeWithReset();
1705
1706   handshakeCallback.waitForHandshake();
1707   EXPECT_NE(
1708       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1709   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1710 }
1711
1712 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1713   // Start listening on a local port
1714   WriteCallbackBase writeCallback;
1715   WriteErrorCallback readCallback(&writeCallback);
1716   HandshakeCallback handshakeCallback(
1717       &readCallback, HandshakeCallback::EXPECT_ERROR);
1718   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1719   TestSSLServer server(&acceptCallback);
1720
1721   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1722   socket->open();
1723   uint8_t buf[3] = {0x16, 0x03, 0x01};
1724   socket->write(buf, sizeof(buf));
1725   socket->close();
1726
1727   handshakeCallback.waitForHandshake();
1728 #if FOLLY_OPENSSL_IS_110
1729   EXPECT_NE(
1730       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1731 #else
1732   EXPECT_NE(
1733       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1734 #endif
1735 }
1736
1737 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1738   // Start listening on a local port
1739   WriteCallbackBase writeCallback;
1740   WriteErrorCallback readCallback(&writeCallback);
1741   HandshakeCallback handshakeCallback(
1742       &readCallback, HandshakeCallback::EXPECT_ERROR);
1743   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1744   TestSSLServer server(&acceptCallback);
1745
1746   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1747   socket->open();
1748   uint8_t buf[256] = {0x16, 0x03};
1749   memset(buf + 2, 'a', sizeof(buf) - 2);
1750   socket->write(buf, sizeof(buf));
1751   socket->close();
1752
1753   handshakeCallback.waitForHandshake();
1754   EXPECT_NE(
1755       handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
1756 #if defined(OPENSSL_IS_BORINGSSL)
1757   EXPECT_NE(
1758       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1759       std::string::npos);
1760 #elif FOLLY_OPENSSL_IS_110
1761   EXPECT_NE(
1762       handshakeCallback.errorString_.find("packet length too long"),
1763       std::string::npos);
1764 #else
1765   EXPECT_NE(
1766       handshakeCallback.errorString_.find("unknown protocol"),
1767       std::string::npos);
1768 #endif
1769 }
1770
1771 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1772   using folly::ssl::OpenSSLUtils;
1773   EXPECT_EQ(
1774       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1775   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1776   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1777   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1778   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1779 }
1780
1781 #if FOLLY_ALLOW_TFO
1782
1783 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1784  public:
1785   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1786
1787   explicit MockAsyncTFOSSLSocket(
1788       std::shared_ptr<folly::SSLContext> sslCtx,
1789       EventBase* evb)
1790       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1791
1792   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1793 };
1794
1795 /**
1796  * Test connecting to, writing to, reading from, and closing the
1797  * connection to the SSL server with TFO.
1798  */
1799 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1800   // Start listening on a local port
1801   WriteCallbackBase writeCallback;
1802   ReadCallback readCallback(&writeCallback);
1803   HandshakeCallback handshakeCallback(&readCallback);
1804   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1805   TestSSLServer server(&acceptCallback, true);
1806
1807   // Set up SSL context.
1808   auto sslContext = std::make_shared<SSLContext>();
1809
1810   // connect
1811   auto socket =
1812       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1813   socket->enableTFO();
1814   socket->open();
1815
1816   // write()
1817   std::array<uint8_t, 128> buf;
1818   memset(buf.data(), 'a', buf.size());
1819   socket->write(buf.data(), buf.size());
1820
1821   // read()
1822   std::array<uint8_t, 128> readbuf;
1823   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1824   EXPECT_EQ(bytesRead, 128);
1825   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1826
1827   // close()
1828   socket->close();
1829 }
1830
1831 /**
1832  * Test connecting to, writing to, reading from, and closing the
1833  * connection to the SSL server with TFO.
1834  */
1835 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1836   // Start listening on a local port
1837   WriteCallbackBase writeCallback;
1838   ReadCallback readCallback(&writeCallback);
1839   HandshakeCallback handshakeCallback(&readCallback);
1840   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1841   TestSSLServer server(&acceptCallback, false);
1842
1843   // Set up SSL context.
1844   auto sslContext = std::make_shared<SSLContext>();
1845
1846   // connect
1847   auto socket =
1848       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1849   socket->enableTFO();
1850   socket->open();
1851
1852   // write()
1853   std::array<uint8_t, 128> buf;
1854   memset(buf.data(), 'a', buf.size());
1855   socket->write(buf.data(), buf.size());
1856
1857   // read()
1858   std::array<uint8_t, 128> readbuf;
1859   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1860   EXPECT_EQ(bytesRead, 128);
1861   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1862
1863   // close()
1864   socket->close();
1865 }
1866
1867 class ConnCallback : public AsyncSocket::ConnectCallback {
1868  public:
1869   void connectSuccess() noexcept override {
1870     state = State::SUCCESS;
1871   }
1872
1873   void connectErr(const AsyncSocketException& ex) noexcept override {
1874     state = State::ERROR;
1875     error = ex.what();
1876   }
1877
1878   enum class State { WAITING, SUCCESS, ERROR };
1879
1880   State state{State::WAITING};
1881   std::string error;
1882 };
1883
1884 template <class Cardinality>
1885 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1886     EventBase* evb,
1887     const SocketAddress& address,
1888     Cardinality cardinality) {
1889   // Set up SSL context.
1890   auto sslContext = std::make_shared<SSLContext>();
1891
1892   // connect
1893   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1894       new MockAsyncTFOSSLSocket(sslContext, evb));
1895   socket->enableTFO();
1896
1897   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1898       .Times(cardinality)
1899       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1900         sockaddr_storage addr;
1901         auto len = address.getAddress(&addr);
1902         return connect(fd, (const struct sockaddr*)&addr, len);
1903       }));
1904   return socket;
1905 }
1906
1907 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1908   // Start listening on a local port
1909   WriteCallbackBase writeCallback;
1910   ReadCallback readCallback(&writeCallback);
1911   HandshakeCallback handshakeCallback(&readCallback);
1912   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1913   TestSSLServer server(&acceptCallback, true);
1914
1915   EventBase evb;
1916
1917   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1918   ConnCallback ccb;
1919   socket->connect(&ccb, server.getAddress(), 30);
1920
1921   evb.loop();
1922   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1923
1924   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1925   evb.loop();
1926
1927   BlockingSocket sock(std::move(socket));
1928   // write()
1929   std::array<uint8_t, 128> buf;
1930   memset(buf.data(), 'a', buf.size());
1931   sock.write(buf.data(), buf.size());
1932
1933   // read()
1934   std::array<uint8_t, 128> readbuf;
1935   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1936   EXPECT_EQ(bytesRead, 128);
1937   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1938
1939   // close()
1940   sock.close();
1941 }
1942
1943 #if !defined(OPENSSL_IS_BORINGSSL)
1944 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1945   // Start listening on a local port
1946   ConnectTimeoutCallback acceptCallback;
1947   TestSSLServer server(&acceptCallback, true);
1948
1949   // Set up SSL context.
1950   auto sslContext = std::make_shared<SSLContext>();
1951
1952   // connect
1953   auto socket =
1954       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1955   socket->enableTFO();
1956   EXPECT_THROW(
1957       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1958 }
1959 #endif
1960
1961 #if !defined(OPENSSL_IS_BORINGSSL)
1962 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1963   // Start listening on a local port
1964   ConnectTimeoutCallback acceptCallback;
1965   TestSSLServer server(&acceptCallback, true);
1966
1967   EventBase evb;
1968
1969   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1970   ConnCallback ccb;
1971   // Set a short timeout
1972   socket->connect(&ccb, server.getAddress(), 1);
1973
1974   evb.loop();
1975   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1976 }
1977 #endif
1978
1979 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1980   // Start listening on a local port
1981   EmptyReadCallback readCallback;
1982   HandshakeCallback handshakeCallback(
1983       &readCallback, HandshakeCallback::EXPECT_ERROR);
1984   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1985   TestSSLServer server(&acceptCallback, true);
1986
1987   EventBase evb;
1988
1989   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1990   ConnCallback ccb;
1991   socket->connect(&ccb, server.getAddress(), 100);
1992
1993   evb.loop();
1994   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1995   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1996 }
1997
1998 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1999   // Start listening on a local port
2000   EventBase evb;
2001
2002   // Hopefully nothing is listening on this address
2003   SocketAddress addr("127.0.0.1", 65535);
2004   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
2005   ConnCallback ccb;
2006   socket->connect(&ccb, addr, 100);
2007
2008   evb.loop();
2009   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2010   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
2011 }
2012
2013 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
2014   EventBase clientEventBase;
2015   EventBase serverEventBase;
2016   auto clientCtx = std::make_shared<SSLContext>();
2017   auto dfServerCtx = std::make_shared<SSLContext>();
2018   std::array<int, 2> fds;
2019   getfds(fds.data());
2020   getctx(clientCtx, dfServerCtx);
2021
2022   AsyncSSLSocket::UniquePtr clientSockPtr(
2023       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2024   AsyncSSLSocket::UniquePtr serverSockPtr(
2025       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
2026   auto clientSock = clientSockPtr.get();
2027   auto serverSock = serverSockPtr.get();
2028   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2029
2030   // Steal some data from the server.
2031   clientEventBase.loopOnce();
2032   std::array<uint8_t, 10> buf;
2033   recv(fds[1], buf.data(), buf.size(), 0);
2034
2035   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2036   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2037   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2038     serverEventBase.loopOnce();
2039     clientEventBase.loopOnce();
2040   }
2041
2042   EXPECT_TRUE(client.handshakeSuccess_);
2043   EXPECT_TRUE(server.handshakeSuccess_);
2044   EXPECT_EQ(
2045       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2046 }
2047
2048 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2049   EventBase clientEventBase;
2050   EventBase serverEventBase;
2051   auto clientCtx = std::make_shared<SSLContext>();
2052   auto dfServerCtx = std::make_shared<SSLContext>();
2053   std::array<int, 2> fds;
2054   getfds(fds.data());
2055   getctx(clientCtx, dfServerCtx);
2056
2057   AsyncSSLSocket::UniquePtr clientSockPtr(
2058       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2059   AsyncSocket::UniquePtr serverSockPtr(
2060       new AsyncSocket(&serverEventBase, fds[1]));
2061   auto clientSock = clientSockPtr.get();
2062   auto serverSock = serverSockPtr.get();
2063   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2064
2065   // Steal some data from the server.
2066   clientEventBase.loopOnce();
2067   std::array<uint8_t, 10> buf;
2068   recv(fds[1], buf.data(), buf.size(), 0);
2069   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2070   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2071       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2072   auto serverSSLSock = serverSSLSockPtr.get();
2073   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2074   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2075     serverEventBase.loopOnce();
2076     clientEventBase.loopOnce();
2077   }
2078
2079   EXPECT_TRUE(client.handshakeSuccess_);
2080   EXPECT_TRUE(server.handshakeSuccess_);
2081   EXPECT_EQ(
2082       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2083 }
2084
2085 /**
2086  * Test overriding the flags passed to "sendmsg()" system call,
2087  * and verifying that write requests fail properly.
2088  */
2089 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2090   // Start listening on a local port
2091   SendMsgFlagsCallback msgCallback;
2092   ExpectWriteErrorCallback writeCallback(&msgCallback);
2093   ReadCallback readCallback(&writeCallback);
2094   HandshakeCallback handshakeCallback(&readCallback);
2095   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2096   TestSSLServer server(&acceptCallback);
2097
2098   // Set up SSL context.
2099   auto sslContext = std::make_shared<SSLContext>();
2100   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2101
2102   // connect
2103   auto socket =
2104       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2105   socket->open();
2106
2107   // Setting flags to "-1" to trigger "Invalid argument" error
2108   // on attempt to use this flags in sendmsg() system call.
2109   msgCallback.resetFlags(-1);
2110
2111   // write()
2112   std::vector<uint8_t> buf(128, 'a');
2113   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2114
2115   // close()
2116   socket->close();
2117
2118   cerr << "SendMsgParamsCallback test completed" << endl;
2119 }
2120
2121 #ifdef MSG_ERRQUEUE
2122 /**
2123  * Test connecting to, writing to, reading from, and closing the
2124  * connection to the SSL server.
2125  */
2126 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2127   // This test requires Linux kernel v4.6 or later
2128   struct utsname s_uname;
2129   memset(&s_uname, 0, sizeof(s_uname));
2130   ASSERT_EQ(uname(&s_uname), 0);
2131   int major, minor;
2132   folly::StringPiece extra;
2133   if (folly::split<false>(
2134           '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2135     if (major < 4 || (major == 4 && minor < 6)) {
2136       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2137                 << "kernel ver. " << s_uname.release << " detected).";
2138       return;
2139     }
2140   }
2141
2142   // Start listening on a local port
2143   SendMsgDataCallback msgCallback;
2144   WriteCheckTimestampCallback writeCallback(&msgCallback);
2145   ReadCallback readCallback(&writeCallback);
2146   HandshakeCallback handshakeCallback(&readCallback);
2147   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2148   TestSSLServer server(&acceptCallback);
2149
2150   // Set up SSL context.
2151   auto sslContext = std::make_shared<SSLContext>();
2152   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2153
2154   // connect
2155   auto socket =
2156       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2157   socket->open();
2158
2159   // Adding MSG_EOR flag to the message flags - it'll trigger
2160   // timestamp generation for the last byte of the message.
2161   msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
2162
2163   // Init ancillary data buffer to trigger timestamp notification
2164   union {
2165     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2166     struct cmsghdr cmsg;
2167   } u;
2168   u.cmsg.cmsg_level = SOL_SOCKET;
2169   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2170   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2171   uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
2172       SOF_TIMESTAMPING_TX_ACK;
2173   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2174   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2175   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2176   msgCallback.resetData(std::move(ctrl));
2177
2178   // write()
2179   std::vector<uint8_t> buf(128, 'a');
2180   socket->write(buf.data(), buf.size());
2181
2182   // read()
2183   std::vector<uint8_t> readbuf(buf.size());
2184   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2185   EXPECT_EQ(bytesRead, buf.size());
2186   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2187
2188   writeCallback.checkForTimestampNotifications();
2189
2190   // close()
2191   socket->close();
2192
2193   cerr << "SendMsgDataCallback test completed" << endl;
2194 }
2195 #endif // MSG_ERRQUEUE
2196
2197 #endif
2198
2199 } // namespace
2200
2201 #ifdef SIGPIPE
2202 ///////////////////////////////////////////////////////////////////////////
2203 // init_unit_test_suite
2204 ///////////////////////////////////////////////////////////////////////////
2205 namespace {
2206 struct Initializer {
2207   Initializer() {
2208     signal(SIGPIPE, SIG_IGN);
2209   }
2210 };
2211 Initializer initializer;
2212 } // namespace
2213 #endif