Test server-side getPeerCert().
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #ifdef MSG_ERRQUEUE
41 #include <sys/utsname.h>
42 #endif
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::list;
50
51 using namespace testing;
52
53 namespace folly {
54 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
55 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
56 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
57
58 constexpr size_t SSLClient::kMaxReadBufferSz;
59 constexpr size_t SSLClient::kMaxReadsPerEvent;
60
61 void getfds(int fds[2]) {
62   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
63     FAIL() << "failed to create socketpair: " << strerror(errno);
64   }
65   for (int idx = 0; idx < 2; ++idx) {
66     int flags = fcntl(fds[idx], F_GETFL, 0);
67     if (flags == -1) {
68       FAIL() << "failed to get flags for socket " << idx << ": "
69              << strerror(errno);
70     }
71     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
72       FAIL() << "failed to put socket " << idx
73              << " in non-blocking mode: " << strerror(errno);
74     }
75   }
76 }
77
78 void getctx(
79     std::shared_ptr<folly::SSLContext> clientCtx,
80     std::shared_ptr<folly::SSLContext> serverCtx) {
81   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
82
83   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
84   serverCtx->loadCertificate(kTestCert);
85   serverCtx->loadPrivateKey(kTestKey);
86 }
87
88 void sslsocketpair(
89     EventBase* eventBase,
90     AsyncSSLSocket::UniquePtr* clientSock,
91     AsyncSSLSocket::UniquePtr* serverSock) {
92   auto clientCtx = std::make_shared<folly::SSLContext>();
93   auto serverCtx = std::make_shared<folly::SSLContext>();
94   int fds[2];
95   getfds(fds);
96   getctx(clientCtx, serverCtx);
97   clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
98   serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
99
100   // (*clientSock)->setSendTimeout(100);
101   // (*serverSock)->setSendTimeout(100);
102 }
103
104 // client protocol filters
105 bool clientProtoFilterPickPony(
106     unsigned char** client,
107     unsigned int* client_len,
108     const unsigned char*,
109     unsigned int) {
110   // the protocol string in length prefixed byte string. the
111   // length byte is not included in the length
112   static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
113   *client = p;
114   *client_len = 7;
115   return true;
116 }
117
118 bool clientProtoFilterPickNone(
119     unsigned char**,
120     unsigned int*,
121     const unsigned char*,
122     unsigned int) {
123   return false;
124 }
125
126 std::string getFileAsBuf(const char* fileName) {
127   std::string buffer;
128   folly::readFile(fileName, buffer);
129   return buffer;
130 }
131
132 std::string getCommonName(X509* cert) {
133   X509_NAME* subject = X509_get_subject_name(cert);
134   std::string cn;
135   cn.resize(ub_common_name);
136   X509_NAME_get_text_by_NID(
137       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
138   return cn;
139 }
140
141 /**
142  * Test connecting to, writing to, reading from, and closing the
143  * connection to the SSL server.
144  */
145 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
146   // Start listening on a local port
147   WriteCallbackBase writeCallback;
148   ReadCallback readCallback(&writeCallback);
149   HandshakeCallback handshakeCallback(&readCallback);
150   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
151   TestSSLServer server(&acceptCallback);
152
153   // Set up SSL context.
154   std::shared_ptr<SSLContext> sslContext(new SSLContext());
155   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
156   // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
157   // sslContext->authenticate(true, false);
158
159   // connect
160   auto socket =
161       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
162   socket->open(std::chrono::milliseconds(10000));
163
164   // write()
165   uint8_t buf[128];
166   memset(buf, 'a', sizeof(buf));
167   socket->write(buf, sizeof(buf));
168
169   // read()
170   uint8_t readbuf[128];
171   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
172   EXPECT_EQ(bytesRead, 128);
173   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
174
175   // close()
176   socket->close();
177
178   cerr << "ConnectWriteReadClose test completed" << endl;
179   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
180 }
181
182 /**
183  * Test reading after server close.
184  */
185 TEST(AsyncSSLSocketTest, ReadAfterClose) {
186   // Start listening on a local port
187   WriteCallbackBase writeCallback;
188   ReadEOFCallback readCallback(&writeCallback);
189   HandshakeCallback handshakeCallback(&readCallback);
190   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
191   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
192
193   // Set up SSL context.
194   auto sslContext = std::make_shared<SSLContext>();
195   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
196
197   auto socket =
198       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
199   socket->open();
200
201   // This should trigger an EOF on the client.
202   auto evb = handshakeCallback.getSocket()->getEventBase();
203   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
204   std::array<uint8_t, 128> readbuf;
205   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
206   EXPECT_EQ(0, bytesRead);
207 }
208
209 /**
210  * Test bad renegotiation
211  */
212 #if !defined(OPENSSL_IS_BORINGSSL)
213 TEST(AsyncSSLSocketTest, Renegotiate) {
214   EventBase eventBase;
215   auto clientCtx = std::make_shared<SSLContext>();
216   auto dfServerCtx = std::make_shared<SSLContext>();
217   std::array<int, 2> fds;
218   getfds(fds.data());
219   getctx(clientCtx, dfServerCtx);
220
221   AsyncSSLSocket::UniquePtr clientSock(
222       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
223   AsyncSSLSocket::UniquePtr serverSock(
224       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
225   SSLHandshakeClient client(std::move(clientSock), true, true);
226   RenegotiatingServer server(std::move(serverSock));
227
228   while (!client.handshakeSuccess_ && !client.handshakeError_) {
229     eventBase.loopOnce();
230   }
231
232   ASSERT_TRUE(client.handshakeSuccess_);
233
234   auto sslSock = std::move(client).moveSocket();
235   sslSock->detachEventBase();
236   // This is nasty, however we don't want to add support for
237   // renegotiation in AsyncSSLSocket.
238   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
239
240   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
241
242   std::thread t([&]() { eventBase.loopForever(); });
243
244   // Trigger the renegotiation.
245   std::array<uint8_t, 128> buf;
246   memset(buf.data(), 'a', buf.size());
247   try {
248     socket->write(buf.data(), buf.size());
249   } catch (AsyncSocketException& e) {
250     LOG(INFO) << "client got error " << e.what();
251   }
252   eventBase.terminateLoopSoon();
253   t.join();
254
255   eventBase.loop();
256   ASSERT_TRUE(server.renegotiationError_);
257 }
258 #endif
259
260 /**
261  * Negative test for handshakeError().
262  */
263 TEST(AsyncSSLSocketTest, HandshakeError) {
264   // Start listening on a local port
265   WriteCallbackBase writeCallback;
266   WriteErrorCallback readCallback(&writeCallback);
267   HandshakeCallback handshakeCallback(&readCallback);
268   HandshakeErrorCallback acceptCallback(&handshakeCallback);
269   TestSSLServer server(&acceptCallback);
270
271   // Set up SSL context.
272   std::shared_ptr<SSLContext> sslContext(new SSLContext());
273   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
274
275   // connect
276   auto socket =
277       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
278   // read()
279   bool ex = false;
280   try {
281     socket->open();
282
283     uint8_t readbuf[128];
284     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
285     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
286   } catch (AsyncSocketException&) {
287     ex = true;
288   }
289   EXPECT_TRUE(ex);
290
291   // close()
292   socket->close();
293   cerr << "HandshakeError test completed" << endl;
294 }
295
296 /**
297  * Negative test for readError().
298  */
299 TEST(AsyncSSLSocketTest, ReadError) {
300   // Start listening on a local port
301   WriteCallbackBase writeCallback;
302   ReadErrorCallback readCallback(&writeCallback);
303   HandshakeCallback handshakeCallback(&readCallback);
304   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
305   TestSSLServer server(&acceptCallback);
306
307   // Set up SSL context.
308   std::shared_ptr<SSLContext> sslContext(new SSLContext());
309   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
310
311   // connect
312   auto socket =
313       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
314   socket->open();
315
316   // write something to trigger ssl handshake
317   uint8_t buf[128];
318   memset(buf, 'a', sizeof(buf));
319   socket->write(buf, sizeof(buf));
320
321   socket->close();
322   cerr << "ReadError test completed" << endl;
323 }
324
325 /**
326  * Negative test for writeError().
327  */
328 TEST(AsyncSSLSocketTest, WriteError) {
329   // Start listening on a local port
330   WriteCallbackBase writeCallback;
331   WriteErrorCallback readCallback(&writeCallback);
332   HandshakeCallback handshakeCallback(&readCallback);
333   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
334   TestSSLServer server(&acceptCallback);
335
336   // Set up SSL context.
337   std::shared_ptr<SSLContext> sslContext(new SSLContext());
338   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
339
340   // connect
341   auto socket =
342       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
343   socket->open();
344
345   // write something to trigger ssl handshake
346   uint8_t buf[128];
347   memset(buf, 'a', sizeof(buf));
348   socket->write(buf, sizeof(buf));
349
350   socket->close();
351   cerr << "WriteError test completed" << endl;
352 }
353
354 /**
355  * Test a socket with TCP_NODELAY unset.
356  */
357 TEST(AsyncSSLSocketTest, SocketWithDelay) {
358   // Start listening on a local port
359   WriteCallbackBase writeCallback;
360   ReadCallback readCallback(&writeCallback);
361   HandshakeCallback handshakeCallback(&readCallback);
362   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
363   TestSSLServer server(&acceptCallback);
364
365   // Set up SSL context.
366   std::shared_ptr<SSLContext> sslContext(new SSLContext());
367   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
368
369   // connect
370   auto socket =
371       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
372   socket->open();
373
374   // write()
375   uint8_t buf[128];
376   memset(buf, 'a', sizeof(buf));
377   socket->write(buf, sizeof(buf));
378
379   // read()
380   uint8_t readbuf[128];
381   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
382   EXPECT_EQ(bytesRead, 128);
383   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
384
385   // close()
386   socket->close();
387
388   cerr << "SocketWithDelay test completed" << endl;
389 }
390
391 using NextProtocolTypePair =
392     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
393
394 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
395   // For matching protos
396  public:
397   void SetUp() override {
398     getctx(clientCtx, serverCtx);
399   }
400
401   void connect(bool unset = false) {
402     getfds(fds);
403
404     if (unset) {
405       // unsetting NPN for any of [client, server] is enough to make NPN not
406       // work
407       clientCtx->unsetNextProtocols();
408     }
409
410     AsyncSSLSocket::UniquePtr clientSock(
411         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
412     AsyncSSLSocket::UniquePtr serverSock(
413         new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
414     client = std::make_unique<NpnClient>(std::move(clientSock));
415     server = std::make_unique<NpnServer>(std::move(serverSock));
416
417     eventBase.loop();
418   }
419
420   void expectProtocol(const std::string& proto) {
421     expectHandshakeSuccess();
422     EXPECT_NE(client->nextProtoLength, 0);
423     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
424     EXPECT_EQ(
425         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
426         0);
427     string selected((const char*)client->nextProto, client->nextProtoLength);
428     EXPECT_EQ(proto, selected);
429   }
430
431   void expectNoProtocol() {
432     expectHandshakeSuccess();
433     EXPECT_EQ(client->nextProtoLength, 0);
434     EXPECT_EQ(server->nextProtoLength, 0);
435     EXPECT_EQ(client->nextProto, nullptr);
436     EXPECT_EQ(server->nextProto, nullptr);
437   }
438
439   void expectProtocolType() {
440     expectHandshakeSuccess();
441     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
442         GetParam().second == SSLContext::NextProtocolType::ANY) {
443       EXPECT_EQ(client->protocolType, server->protocolType);
444     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
445                GetParam().second == SSLContext::NextProtocolType::ANY) {
446       // Well not much we can say
447     } else {
448       expectProtocolType(GetParam());
449     }
450   }
451
452   void expectProtocolType(NextProtocolTypePair expected) {
453     expectHandshakeSuccess();
454     EXPECT_EQ(client->protocolType, expected.first);
455     EXPECT_EQ(server->protocolType, expected.second);
456   }
457
458   void expectHandshakeSuccess() {
459     EXPECT_FALSE(client->except.hasValue())
460         << "client handshake error: " << client->except->what();
461     EXPECT_FALSE(server->except.hasValue())
462         << "server handshake error: " << server->except->what();
463   }
464
465   void expectHandshakeError() {
466     EXPECT_TRUE(client->except.hasValue())
467         << "Expected client handshake error!";
468     EXPECT_TRUE(server->except.hasValue())
469         << "Expected server handshake error!";
470   }
471
472   EventBase eventBase;
473   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
474   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
475   int fds[2];
476   std::unique_ptr<NpnClient> client;
477   std::unique_ptr<NpnServer> server;
478 };
479
480 class NextProtocolTLSExtTest : public NextProtocolTest {
481   // For extended TLS protos
482 };
483
484 class NextProtocolNPNOnlyTest : public NextProtocolTest {
485   // For mismatching protos
486 };
487
488 class NextProtocolMismatchTest : public NextProtocolTest {
489   // For mismatching protos
490 };
491
492 TEST_P(NextProtocolTest, NpnTestOverlap) {
493   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
494   serverCtx->setAdvertisedNextProtocols(
495       {"foo", "bar", "baz"}, GetParam().second);
496
497   connect();
498
499   expectProtocol("baz");
500   expectProtocolType();
501 }
502
503 TEST_P(NextProtocolTest, NpnTestUnset) {
504   // Identical to above test, except that we want unset NPN before
505   // looping.
506   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
507   serverCtx->setAdvertisedNextProtocols(
508       {"foo", "bar", "baz"}, GetParam().second);
509
510   connect(true /* unset */);
511
512   // if alpn negotiation fails, type will appear as npn
513   expectNoProtocol();
514   EXPECT_EQ(client->protocolType, server->protocolType);
515 }
516
517 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
518   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
519   serverCtx->setAdvertisedNextProtocols(
520       {"foo", "bar", "baz"}, GetParam().second);
521
522   connect();
523
524   expectNoProtocol();
525   expectProtocolType(
526       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
527 }
528
529 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
530 // will fail on 1.0.2 before that.
531 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
532   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
533   serverCtx->setAdvertisedNextProtocols(
534       {"foo", "bar", "baz"}, GetParam().second);
535   connect();
536
537   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
538       GetParam().second == SSLContext::NextProtocolType::ALPN) {
539     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
540     // mismatch should result in a fatal alert, but this is the current behavior
541     // on all OpenSSL versions/variants, and we want to know if it changes.
542     expectNoProtocol();
543   }
544 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
545   else if (
546       GetParam().first == SSLContext::NextProtocolType::ANY &&
547       GetParam().second == SSLContext::NextProtocolType::ANY) {
548 #if FOLLY_OPENSSL_IS_110
549     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
550     // correct behavior per RFC7301
551     expectHandshakeError();
552 #else
553     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
554     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
555     // NPN *and* ALPN
556     expectNoProtocol();
557 #endif
558   }
559 #endif
560   else {
561     expectProtocol("blub");
562     expectProtocolType(
563         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
564   }
565 }
566
567 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
568   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
569   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
570   serverCtx->setAdvertisedNextProtocols(
571       {"foo", "bar", "baz"}, GetParam().second);
572
573   connect();
574
575   expectProtocol("ponies");
576   expectProtocolType();
577 }
578
579 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
580   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
581   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
582   serverCtx->setAdvertisedNextProtocols(
583       {"foo", "bar", "baz"}, GetParam().second);
584
585   connect();
586
587   expectProtocol("blub");
588   expectProtocolType();
589 }
590
591 TEST_P(NextProtocolTest, RandomizedNpnTest) {
592   // Probability that this test will fail is 2^-64, which could be considered
593   // as negligible.
594   const int kTries = 64;
595
596   clientCtx->setAdvertisedNextProtocols(
597       {"foo", "bar", "baz"}, GetParam().first);
598   serverCtx->setRandomizedAdvertisedNextProtocols(
599       {{1, {"foo"}}, {1, {"bar"}}}, GetParam().second);
600
601   std::set<string> selectedProtocols;
602   for (int i = 0; i < kTries; ++i) {
603     connect();
604
605     EXPECT_NE(client->nextProtoLength, 0);
606     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
607     EXPECT_EQ(
608         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
609         0);
610     string selected((const char*)client->nextProto, client->nextProtoLength);
611     selectedProtocols.insert(selected);
612     expectProtocolType();
613   }
614   EXPECT_EQ(selectedProtocols.size(), 2);
615 }
616
617 INSTANTIATE_TEST_CASE_P(
618     AsyncSSLSocketTest,
619     NextProtocolTest,
620     ::testing::Values(
621         NextProtocolTypePair(
622             SSLContext::NextProtocolType::NPN,
623             SSLContext::NextProtocolType::NPN),
624         NextProtocolTypePair(
625             SSLContext::NextProtocolType::NPN,
626             SSLContext::NextProtocolType::ANY),
627         NextProtocolTypePair(
628             SSLContext::NextProtocolType::ANY,
629             SSLContext::NextProtocolType::ANY)));
630
631 #if FOLLY_OPENSSL_HAS_ALPN
632 INSTANTIATE_TEST_CASE_P(
633     AsyncSSLSocketTest,
634     NextProtocolTLSExtTest,
635     ::testing::Values(
636         NextProtocolTypePair(
637             SSLContext::NextProtocolType::ALPN,
638             SSLContext::NextProtocolType::ALPN),
639         NextProtocolTypePair(
640             SSLContext::NextProtocolType::ALPN,
641             SSLContext::NextProtocolType::ANY),
642         NextProtocolTypePair(
643             SSLContext::NextProtocolType::ANY,
644             SSLContext::NextProtocolType::ALPN)));
645 #endif
646
647 INSTANTIATE_TEST_CASE_P(
648     AsyncSSLSocketTest,
649     NextProtocolNPNOnlyTest,
650     ::testing::Values(NextProtocolTypePair(
651         SSLContext::NextProtocolType::NPN,
652         SSLContext::NextProtocolType::NPN)));
653
654 #if FOLLY_OPENSSL_HAS_ALPN
655 INSTANTIATE_TEST_CASE_P(
656     AsyncSSLSocketTest,
657     NextProtocolMismatchTest,
658     ::testing::Values(
659         NextProtocolTypePair(
660             SSLContext::NextProtocolType::NPN,
661             SSLContext::NextProtocolType::ALPN),
662         NextProtocolTypePair(
663             SSLContext::NextProtocolType::ALPN,
664             SSLContext::NextProtocolType::NPN)));
665 #endif
666
667 #ifndef OPENSSL_NO_TLSEXT
668 /**
669  * 1. Client sends TLSEXT_HOSTNAME in client hello.
670  * 2. Server found a match SSL_CTX and use this SSL_CTX to
671  *    continue the SSL handshake.
672  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
673  */
674 TEST(AsyncSSLSocketTest, SNITestMatch) {
675   EventBase eventBase;
676   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
677   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
678   // Use the same SSLContext to continue the handshake after
679   // tlsext_hostname match.
680   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
681   const std::string serverName("xyz.newdev.facebook.com");
682   int fds[2];
683   getfds(fds);
684   getctx(clientCtx, dfServerCtx);
685
686   AsyncSSLSocket::UniquePtr clientSock(
687       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
688   AsyncSSLSocket::UniquePtr serverSock(
689       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
690   SNIClient client(std::move(clientSock));
691   SNIServer server(
692       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
693
694   eventBase.loop();
695
696   EXPECT_TRUE(client.serverNameMatch);
697   EXPECT_TRUE(server.serverNameMatch);
698 }
699
700 /**
701  * 1. Client sends TLSEXT_HOSTNAME in client hello.
702  * 2. Server cannot find a matching SSL_CTX and continue to use
703  *    the current SSL_CTX to do the handshake.
704  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
705  */
706 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
707   EventBase eventBase;
708   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
709   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
710   // Use the same SSLContext to continue the handshake after
711   // tlsext_hostname match.
712   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
713   const std::string clientRequestingServerName("foo.com");
714   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
715
716   int fds[2];
717   getfds(fds);
718   getctx(clientCtx, dfServerCtx);
719
720   AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
721       clientCtx, &eventBase, fds[0], clientRequestingServerName));
722   AsyncSSLSocket::UniquePtr serverSock(
723       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
724   SNIClient client(std::move(clientSock));
725   SNIServer server(
726       std::move(serverSock),
727       dfServerCtx,
728       hskServerCtx,
729       serverExpectedServerName);
730
731   eventBase.loop();
732
733   EXPECT_TRUE(!client.serverNameMatch);
734   EXPECT_TRUE(!server.serverNameMatch);
735 }
736 /**
737  * 1. Client sends TLSEXT_HOSTNAME in client hello.
738  * 2. We then change the serverName.
739  * 3. We expect that we get 'false' as the result for serNameMatch.
740  */
741
742 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
743   EventBase eventBase;
744   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
745   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
746   // Use the same SSLContext to continue the handshake after
747   // tlsext_hostname match.
748   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
749   const std::string serverName("xyz.newdev.facebook.com");
750   int fds[2];
751   getfds(fds);
752   getctx(clientCtx, dfServerCtx);
753
754   AsyncSSLSocket::UniquePtr clientSock(
755       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
756   // Change the server name
757   std::string newName("new.com");
758   clientSock->setServerName(newName);
759   AsyncSSLSocket::UniquePtr serverSock(
760       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
761   SNIClient client(std::move(clientSock));
762   SNIServer server(
763       std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
764
765   eventBase.loop();
766
767   EXPECT_TRUE(!client.serverNameMatch);
768 }
769
770 /**
771  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
772  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
773  */
774 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
775   EventBase eventBase;
776   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
777   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
778   // Use the same SSLContext to continue the handshake after
779   // tlsext_hostname match.
780   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
781   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
782
783   int fds[2];
784   getfds(fds);
785   getctx(clientCtx, dfServerCtx);
786
787   AsyncSSLSocket::UniquePtr clientSock(
788       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
789   AsyncSSLSocket::UniquePtr serverSock(
790       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
791   SNIClient client(std::move(clientSock));
792   SNIServer server(
793       std::move(serverSock),
794       dfServerCtx,
795       hskServerCtx,
796       serverExpectedServerName);
797
798   eventBase.loop();
799
800   EXPECT_TRUE(!client.serverNameMatch);
801   EXPECT_TRUE(!server.serverNameMatch);
802 }
803
804 #endif
805 /**
806  * Test SSL client socket
807  */
808 TEST(AsyncSSLSocketTest, SSLClientTest) {
809   // Start listening on a local port
810   WriteCallbackBase writeCallback;
811   ReadCallback readCallback(&writeCallback);
812   HandshakeCallback handshakeCallback(&readCallback);
813   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
814   TestSSLServer server(&acceptCallback);
815
816   // Set up SSL client
817   EventBase eventBase;
818   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
819
820   client->connect();
821   EventBaseAborter eba(&eventBase, 3000);
822   eventBase.loop();
823
824   EXPECT_EQ(client->getMiss(), 1);
825   EXPECT_EQ(client->getHit(), 0);
826
827   cerr << "SSLClientTest test completed" << endl;
828 }
829
830 /**
831  * Test SSL client socket session re-use
832  */
833 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
834   // Start listening on a local port
835   WriteCallbackBase writeCallback;
836   ReadCallback readCallback(&writeCallback);
837   HandshakeCallback handshakeCallback(&readCallback);
838   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
839   TestSSLServer server(&acceptCallback);
840
841   // Set up SSL client
842   EventBase eventBase;
843   auto client =
844       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
845
846   client->connect();
847   EventBaseAborter eba(&eventBase, 3000);
848   eventBase.loop();
849
850   EXPECT_EQ(client->getMiss(), 1);
851   EXPECT_EQ(client->getHit(), 9);
852
853   cerr << "SSLClientTestReuse test completed" << endl;
854 }
855
856 /**
857  * Test SSL client socket timeout
858  */
859 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
860   // Start listening on a local port
861   EmptyReadCallback readCallback;
862   HandshakeCallback handshakeCallback(
863       &readCallback, HandshakeCallback::EXPECT_ERROR);
864   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
865   TestSSLServer server(&acceptCallback);
866
867   // Set up SSL client
868   EventBase eventBase;
869   auto client =
870       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
871   client->connect(true /* write before connect completes */);
872   EventBaseAborter eba(&eventBase, 3000);
873   eventBase.loop();
874
875   usleep(100000);
876   // This is checking that the connectError callback precedes any queued
877   // writeError callbacks.  This matches AsyncSocket's behavior
878   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
879   EXPECT_EQ(client->getErrors(), 1);
880   EXPECT_EQ(client->getMiss(), 0);
881   EXPECT_EQ(client->getHit(), 0);
882
883   cerr << "SSLClientTimeoutTest test completed" << endl;
884 }
885
886 // The next 3 tests need an FB-only extension, and will fail without it
887 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
888 /**
889  * Test SSL server async cache
890  */
891 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
892   // Start listening on a local port
893   WriteCallbackBase writeCallback;
894   ReadCallback readCallback(&writeCallback);
895   HandshakeCallback handshakeCallback(&readCallback);
896   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
897   TestSSLAsyncCacheServer server(&acceptCallback);
898
899   // Set up SSL client
900   EventBase eventBase;
901   auto client =
902       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
903
904   client->connect();
905   EventBaseAborter eba(&eventBase, 3000);
906   eventBase.loop();
907
908   EXPECT_EQ(server.getAsyncCallbacks(), 18);
909   EXPECT_EQ(server.getAsyncLookups(), 9);
910   EXPECT_EQ(client->getMiss(), 10);
911   EXPECT_EQ(client->getHit(), 0);
912
913   cerr << "SSLServerAsyncCacheTest test completed" << endl;
914 }
915
916 /**
917  * Test SSL server accept timeout with cache path
918  */
919 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
920   // Start listening on a local port
921   WriteCallbackBase writeCallback;
922   ReadCallback readCallback(&writeCallback);
923   HandshakeCallback handshakeCallback(&readCallback);
924   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
925   TestSSLAsyncCacheServer server(&acceptCallback);
926
927   // Set up SSL client
928   EventBase eventBase;
929   // only do a TCP connect
930   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
931   sock->connect(nullptr, server.getAddress());
932
933   EmptyReadCallback clientReadCallback;
934   clientReadCallback.tcpSocket_ = sock;
935   sock->setReadCB(&clientReadCallback);
936
937   EventBaseAborter eba(&eventBase, 3000);
938   eventBase.loop();
939
940   EXPECT_EQ(readCallback.state, STATE_WAITING);
941
942   cerr << "SSLServerTimeoutTest test completed" << endl;
943 }
944
945 /**
946  * Test SSL server accept timeout with cache path
947  */
948 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
949   // Start listening on a local port
950   WriteCallbackBase writeCallback;
951   ReadCallback readCallback(&writeCallback);
952   HandshakeCallback handshakeCallback(&readCallback);
953   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
954   TestSSLAsyncCacheServer server(&acceptCallback);
955
956   // Set up SSL client
957   EventBase eventBase;
958   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
959
960   client->connect();
961   EventBaseAborter eba(&eventBase, 3000);
962   eventBase.loop();
963
964   EXPECT_EQ(server.getAsyncCallbacks(), 1);
965   EXPECT_EQ(server.getAsyncLookups(), 1);
966   EXPECT_EQ(client->getErrors(), 1);
967   EXPECT_EQ(client->getMiss(), 1);
968   EXPECT_EQ(client->getHit(), 0);
969
970   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
971 }
972
973 /**
974  * Test SSL server accept timeout with cache path
975  */
976 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
977   // Start listening on a local port
978   WriteCallbackBase writeCallback;
979   ReadCallback readCallback(&writeCallback);
980   HandshakeCallback handshakeCallback(
981       &readCallback, HandshakeCallback::EXPECT_ERROR);
982   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
983   TestSSLAsyncCacheServer server(&acceptCallback, 500);
984
985   // Set up SSL client
986   EventBase eventBase;
987   auto client =
988       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
989
990   client->connect();
991   EventBaseAborter eba(&eventBase, 3000);
992   eventBase.loop();
993
994   server.getEventBase().runInEventBaseThread(
995       [&handshakeCallback] { handshakeCallback.closeSocket(); });
996   // give time for the cache lookup to come back and find it closed
997   handshakeCallback.waitForHandshake();
998
999   EXPECT_EQ(server.getAsyncCallbacks(), 1);
1000   EXPECT_EQ(server.getAsyncLookups(), 1);
1001   EXPECT_EQ(client->getErrors(), 1);
1002   EXPECT_EQ(client->getMiss(), 1);
1003   EXPECT_EQ(client->getHit(), 0);
1004
1005   cerr << "SSLServerCacheCloseTest test completed" << endl;
1006 }
1007 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1008
1009 /**
1010  * Verify Client Ciphers obtained using SSL MSG Callback.
1011  */
1012 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1013   EventBase eventBase;
1014   auto clientCtx = std::make_shared<SSLContext>();
1015   auto serverCtx = std::make_shared<SSLContext>();
1016   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1017   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1018   serverCtx->loadPrivateKey(kTestKey);
1019   serverCtx->loadCertificate(kTestCert);
1020   serverCtx->loadTrustedCertificates(kTestCA);
1021   serverCtx->loadClientCAList(kTestCA);
1022
1023   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1024   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1025   clientCtx->loadPrivateKey(kTestKey);
1026   clientCtx->loadCertificate(kTestCert);
1027   clientCtx->loadTrustedCertificates(kTestCA);
1028
1029   int fds[2];
1030   getfds(fds);
1031
1032   AsyncSSLSocket::UniquePtr clientSock(
1033       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1034   AsyncSSLSocket::UniquePtr serverSock(
1035       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1036
1037   SSLHandshakeClient client(std::move(clientSock), true, true);
1038   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1039
1040   eventBase.loop();
1041
1042 #if defined(OPENSSL_IS_BORINGSSL)
1043   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1044 #else
1045   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1046 #endif
1047   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1048   EXPECT_TRUE(client.handshakeVerify_);
1049   EXPECT_TRUE(client.handshakeSuccess_);
1050   EXPECT_TRUE(!client.handshakeError_);
1051   EXPECT_TRUE(server.handshakeVerify_);
1052   EXPECT_TRUE(server.handshakeSuccess_);
1053   EXPECT_TRUE(!server.handshakeError_);
1054 }
1055
1056 /**
1057  * Verify that server is able to get client cert by getPeerCert() API.
1058  */
1059 TEST(AsyncSSLSocketTest, GetClientCertificate) {
1060   EventBase eventBase;
1061   auto clientCtx = std::make_shared<SSLContext>();
1062   auto serverCtx = std::make_shared<SSLContext>();
1063   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1064   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1065   serverCtx->loadPrivateKey(kTestKey);
1066   serverCtx->loadCertificate(kTestCert);
1067   serverCtx->loadTrustedCertificates(kClientTestCA);
1068   serverCtx->loadClientCAList(kClientTestCA);
1069
1070   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1071   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1072   clientCtx->loadPrivateKey(kClientTestKey);
1073   clientCtx->loadCertificate(kClientTestCert);
1074   clientCtx->loadTrustedCertificates(kTestCA);
1075
1076   std::array<int, 2> fds;
1077   getfds(fds.data());
1078
1079   AsyncSSLSocket::UniquePtr clientSock(
1080       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1081   AsyncSSLSocket::UniquePtr serverSock(
1082       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1083
1084   SSLHandshakeClient client(std::move(clientSock), true, true);
1085   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1086
1087   eventBase.loop();
1088
1089   // Handshake should succeed.
1090   EXPECT_TRUE(client.handshakeSuccess_);
1091   EXPECT_TRUE(server.handshakeSuccess_);
1092
1093   // Reclaim the sockets from SSLHandshakeBase.
1094   auto cliSocket = std::move(client).moveSocket();
1095   auto srvSocket = std::move(server).moveSocket();
1096
1097   // Client cert retrieved from server side.
1098   folly::ssl::X509UniquePtr serverPeerCert = srvSocket->getPeerCert();
1099   CHECK(serverPeerCert);
1100
1101   // Client cert retrieved from client side.
1102   const X509* clientSelfCert = cliSocket->getSelfCert();
1103   CHECK(clientSelfCert);
1104
1105   // The two certs should be the same.
1106   EXPECT_EQ(0, X509_cmp(clientSelfCert, serverPeerCert.get()));
1107 }
1108
1109 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1110   EventBase eventBase;
1111   auto ctx = std::make_shared<SSLContext>();
1112
1113   int fds[2];
1114   getfds(fds);
1115
1116   int bufLen = 42;
1117   uint8_t majorVersion = 18;
1118   uint8_t minorVersion = 25;
1119
1120   // Create callback buf
1121   auto buf = IOBuf::create(bufLen);
1122   buf->append(bufLen);
1123   folly::io::RWPrivateCursor cursor(buf.get());
1124   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1125   cursor.write<uint16_t>(0);
1126   cursor.write<uint8_t>(38);
1127   cursor.write<uint8_t>(majorVersion);
1128   cursor.write<uint8_t>(minorVersion);
1129   cursor.skip(32);
1130   cursor.write<uint32_t>(0);
1131
1132   SSL* ssl = ctx->createSSL();
1133   SCOPE_EXIT {
1134     SSL_free(ssl);
1135   };
1136   AsyncSSLSocket::UniquePtr sock(
1137       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1138   sock->enableClientHelloParsing();
1139
1140   // Test client hello parsing in one packet
1141   AsyncSSLSocket::clientHelloParsingCallback(
1142       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1143   buf.reset();
1144
1145   auto parsedClientHello = sock->getClientHelloInfo();
1146   EXPECT_TRUE(parsedClientHello != nullptr);
1147   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1148   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1149 }
1150
1151 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1152   EventBase eventBase;
1153   auto ctx = std::make_shared<SSLContext>();
1154
1155   int fds[2];
1156   getfds(fds);
1157
1158   int bufLen = 42;
1159   uint8_t majorVersion = 18;
1160   uint8_t minorVersion = 25;
1161
1162   // Create callback buf
1163   auto buf = IOBuf::create(bufLen);
1164   buf->append(bufLen);
1165   folly::io::RWPrivateCursor cursor(buf.get());
1166   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1167   cursor.write<uint16_t>(0);
1168   cursor.write<uint8_t>(38);
1169   cursor.write<uint8_t>(majorVersion);
1170   cursor.write<uint8_t>(minorVersion);
1171   cursor.skip(32);
1172   cursor.write<uint32_t>(0);
1173
1174   SSL* ssl = ctx->createSSL();
1175   SCOPE_EXIT {
1176     SSL_free(ssl);
1177   };
1178   AsyncSSLSocket::UniquePtr sock(
1179       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1180   sock->enableClientHelloParsing();
1181
1182   // Test parsing with two packets with first packet size < 3
1183   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1184   AsyncSSLSocket::clientHelloParsingCallback(
1185       0,
1186       0,
1187       SSL3_RT_HANDSHAKE,
1188       bufCopy->data(),
1189       bufCopy->length(),
1190       ssl,
1191       sock.get());
1192   bufCopy.reset();
1193   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1194   AsyncSSLSocket::clientHelloParsingCallback(
1195       0,
1196       0,
1197       SSL3_RT_HANDSHAKE,
1198       bufCopy->data(),
1199       bufCopy->length(),
1200       ssl,
1201       sock.get());
1202   bufCopy.reset();
1203
1204   auto parsedClientHello = sock->getClientHelloInfo();
1205   EXPECT_TRUE(parsedClientHello != nullptr);
1206   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1207   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1208 }
1209
1210 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1211   EventBase eventBase;
1212   auto ctx = std::make_shared<SSLContext>();
1213
1214   int fds[2];
1215   getfds(fds);
1216
1217   int bufLen = 42;
1218   uint8_t majorVersion = 18;
1219   uint8_t minorVersion = 25;
1220
1221   // Create callback buf
1222   auto buf = IOBuf::create(bufLen);
1223   buf->append(bufLen);
1224   folly::io::RWPrivateCursor cursor(buf.get());
1225   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1226   cursor.write<uint16_t>(0);
1227   cursor.write<uint8_t>(38);
1228   cursor.write<uint8_t>(majorVersion);
1229   cursor.write<uint8_t>(minorVersion);
1230   cursor.skip(32);
1231   cursor.write<uint32_t>(0);
1232
1233   SSL* ssl = ctx->createSSL();
1234   SCOPE_EXIT {
1235     SSL_free(ssl);
1236   };
1237   AsyncSSLSocket::UniquePtr sock(
1238       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1239   sock->enableClientHelloParsing();
1240
1241   // Test parsing with multiple small packets
1242   for (uint64_t i = 0; i < buf->length(); i += 3) {
1243     auto bufCopy = folly::IOBuf::copyBuffer(
1244         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1245     AsyncSSLSocket::clientHelloParsingCallback(
1246         0,
1247         0,
1248         SSL3_RT_HANDSHAKE,
1249         bufCopy->data(),
1250         bufCopy->length(),
1251         ssl,
1252         sock.get());
1253     bufCopy.reset();
1254   }
1255
1256   auto parsedClientHello = sock->getClientHelloInfo();
1257   EXPECT_TRUE(parsedClientHello != nullptr);
1258   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1259   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1260 }
1261
1262 /**
1263  * Verify sucessful behavior of SSL certificate validation.
1264  */
1265 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1266   EventBase eventBase;
1267   auto clientCtx = std::make_shared<SSLContext>();
1268   auto dfServerCtx = std::make_shared<SSLContext>();
1269
1270   int fds[2];
1271   getfds(fds);
1272   getctx(clientCtx, dfServerCtx);
1273
1274   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1275   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1276
1277   AsyncSSLSocket::UniquePtr clientSock(
1278       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1279   AsyncSSLSocket::UniquePtr serverSock(
1280       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1281
1282   SSLHandshakeClient client(std::move(clientSock), true, true);
1283   clientCtx->loadTrustedCertificates(kTestCA);
1284
1285   SSLHandshakeServer server(std::move(serverSock), true, true);
1286
1287   eventBase.loop();
1288
1289   EXPECT_TRUE(client.handshakeVerify_);
1290   EXPECT_TRUE(client.handshakeSuccess_);
1291   EXPECT_TRUE(!client.handshakeError_);
1292   EXPECT_LE(0, client.handshakeTime.count());
1293   EXPECT_TRUE(!server.handshakeVerify_);
1294   EXPECT_TRUE(server.handshakeSuccess_);
1295   EXPECT_TRUE(!server.handshakeError_);
1296   EXPECT_LE(0, server.handshakeTime.count());
1297 }
1298
1299 /**
1300  * Verify that the client's verification callback is able to fail SSL
1301  * connection establishment.
1302  */
1303 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1304   EventBase eventBase;
1305   auto clientCtx = std::make_shared<SSLContext>();
1306   auto dfServerCtx = std::make_shared<SSLContext>();
1307
1308   int fds[2];
1309   getfds(fds);
1310   getctx(clientCtx, dfServerCtx);
1311
1312   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1313   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1314
1315   AsyncSSLSocket::UniquePtr clientSock(
1316       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1317   AsyncSSLSocket::UniquePtr serverSock(
1318       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1319
1320   SSLHandshakeClient client(std::move(clientSock), true, false);
1321   clientCtx->loadTrustedCertificates(kTestCA);
1322
1323   SSLHandshakeServer server(std::move(serverSock), true, true);
1324
1325   eventBase.loop();
1326
1327   EXPECT_TRUE(client.handshakeVerify_);
1328   EXPECT_TRUE(!client.handshakeSuccess_);
1329   EXPECT_TRUE(client.handshakeError_);
1330   EXPECT_LE(0, client.handshakeTime.count());
1331   EXPECT_TRUE(!server.handshakeVerify_);
1332   EXPECT_TRUE(!server.handshakeSuccess_);
1333   EXPECT_TRUE(server.handshakeError_);
1334   EXPECT_LE(0, server.handshakeTime.count());
1335 }
1336
1337 /**
1338  * Verify that the options in SSLContext can be overridden in
1339  * sslConnect/Accept.i.e specifying that no validation should be performed
1340  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1341  * the validation callback.
1342  */
1343 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1344   EventBase eventBase;
1345   auto clientCtx = std::make_shared<SSLContext>();
1346   auto dfServerCtx = std::make_shared<SSLContext>();
1347
1348   int fds[2];
1349   getfds(fds);
1350   getctx(clientCtx, dfServerCtx);
1351
1352   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1353   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1354
1355   AsyncSSLSocket::UniquePtr clientSock(
1356       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1357   AsyncSSLSocket::UniquePtr serverSock(
1358       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1359
1360   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1361   clientCtx->loadTrustedCertificates(kTestCA);
1362
1363   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1364
1365   eventBase.loop();
1366
1367   EXPECT_TRUE(!client.handshakeVerify_);
1368   EXPECT_TRUE(client.handshakeSuccess_);
1369   EXPECT_TRUE(!client.handshakeError_);
1370   EXPECT_LE(0, client.handshakeTime.count());
1371   EXPECT_TRUE(!server.handshakeVerify_);
1372   EXPECT_TRUE(server.handshakeSuccess_);
1373   EXPECT_TRUE(!server.handshakeError_);
1374   EXPECT_LE(0, server.handshakeTime.count());
1375 }
1376
1377 /**
1378  * Verify that the options in SSLContext can be overridden in
1379  * sslConnect/Accept. Enable verification even if context says otherwise.
1380  * Test requireClientCert with client cert
1381  */
1382 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1383   EventBase eventBase;
1384   auto clientCtx = std::make_shared<SSLContext>();
1385   auto serverCtx = std::make_shared<SSLContext>();
1386   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1387   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1388   serverCtx->loadPrivateKey(kTestKey);
1389   serverCtx->loadCertificate(kTestCert);
1390   serverCtx->loadTrustedCertificates(kTestCA);
1391   serverCtx->loadClientCAList(kTestCA);
1392
1393   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1394   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1395   clientCtx->loadPrivateKey(kTestKey);
1396   clientCtx->loadCertificate(kTestCert);
1397   clientCtx->loadTrustedCertificates(kTestCA);
1398
1399   int fds[2];
1400   getfds(fds);
1401
1402   AsyncSSLSocket::UniquePtr clientSock(
1403       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1404   AsyncSSLSocket::UniquePtr serverSock(
1405       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1406
1407   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1408   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1409
1410   eventBase.loop();
1411
1412   EXPECT_TRUE(client.handshakeVerify_);
1413   EXPECT_TRUE(client.handshakeSuccess_);
1414   EXPECT_FALSE(client.handshakeError_);
1415   EXPECT_LE(0, client.handshakeTime.count());
1416   EXPECT_TRUE(server.handshakeVerify_);
1417   EXPECT_TRUE(server.handshakeSuccess_);
1418   EXPECT_FALSE(server.handshakeError_);
1419   EXPECT_LE(0, server.handshakeTime.count());
1420 }
1421
1422 /**
1423  * Verify that the client's verification callback is able to override
1424  * the preverification failure and allow a successful connection.
1425  */
1426 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1427   EventBase eventBase;
1428   auto clientCtx = std::make_shared<SSLContext>();
1429   auto dfServerCtx = std::make_shared<SSLContext>();
1430
1431   int fds[2];
1432   getfds(fds);
1433   getctx(clientCtx, dfServerCtx);
1434
1435   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1436   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1437
1438   AsyncSSLSocket::UniquePtr clientSock(
1439       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1440   AsyncSSLSocket::UniquePtr serverSock(
1441       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1442
1443   SSLHandshakeClient client(std::move(clientSock), false, true);
1444   SSLHandshakeServer server(std::move(serverSock), true, true);
1445
1446   eventBase.loop();
1447
1448   EXPECT_TRUE(client.handshakeVerify_);
1449   EXPECT_TRUE(client.handshakeSuccess_);
1450   EXPECT_TRUE(!client.handshakeError_);
1451   EXPECT_LE(0, client.handshakeTime.count());
1452   EXPECT_TRUE(!server.handshakeVerify_);
1453   EXPECT_TRUE(server.handshakeSuccess_);
1454   EXPECT_TRUE(!server.handshakeError_);
1455   EXPECT_LE(0, server.handshakeTime.count());
1456 }
1457
1458 /**
1459  * Verify that specifying that no validation should be performed allows an
1460  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1461  * callback.
1462  */
1463 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1464   EventBase eventBase;
1465   auto clientCtx = std::make_shared<SSLContext>();
1466   auto dfServerCtx = std::make_shared<SSLContext>();
1467
1468   int fds[2];
1469   getfds(fds);
1470   getctx(clientCtx, dfServerCtx);
1471
1472   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1473   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1474
1475   AsyncSSLSocket::UniquePtr clientSock(
1476       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1477   AsyncSSLSocket::UniquePtr serverSock(
1478       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1479
1480   SSLHandshakeClient client(std::move(clientSock), false, false);
1481   SSLHandshakeServer server(std::move(serverSock), false, false);
1482
1483   eventBase.loop();
1484
1485   EXPECT_TRUE(!client.handshakeVerify_);
1486   EXPECT_TRUE(client.handshakeSuccess_);
1487   EXPECT_TRUE(!client.handshakeError_);
1488   EXPECT_LE(0, client.handshakeTime.count());
1489   EXPECT_TRUE(!server.handshakeVerify_);
1490   EXPECT_TRUE(server.handshakeSuccess_);
1491   EXPECT_TRUE(!server.handshakeError_);
1492   EXPECT_LE(0, server.handshakeTime.count());
1493 }
1494
1495 /**
1496  * Test requireClientCert with client cert
1497  */
1498 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1499   EventBase eventBase;
1500   auto clientCtx = std::make_shared<SSLContext>();
1501   auto serverCtx = std::make_shared<SSLContext>();
1502   serverCtx->setVerificationOption(
1503       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1504   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1505   serverCtx->loadPrivateKey(kTestKey);
1506   serverCtx->loadCertificate(kTestCert);
1507   serverCtx->loadTrustedCertificates(kTestCA);
1508   serverCtx->loadClientCAList(kTestCA);
1509
1510   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1511   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1512   clientCtx->loadPrivateKey(kTestKey);
1513   clientCtx->loadCertificate(kTestCert);
1514   clientCtx->loadTrustedCertificates(kTestCA);
1515
1516   int fds[2];
1517   getfds(fds);
1518
1519   AsyncSSLSocket::UniquePtr clientSock(
1520       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1521   AsyncSSLSocket::UniquePtr serverSock(
1522       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1523
1524   SSLHandshakeClient client(std::move(clientSock), true, true);
1525   SSLHandshakeServer server(std::move(serverSock), true, true);
1526
1527   eventBase.loop();
1528
1529   EXPECT_TRUE(client.handshakeVerify_);
1530   EXPECT_TRUE(client.handshakeSuccess_);
1531   EXPECT_FALSE(client.handshakeError_);
1532   EXPECT_LE(0, client.handshakeTime.count());
1533   EXPECT_TRUE(server.handshakeVerify_);
1534   EXPECT_TRUE(server.handshakeSuccess_);
1535   EXPECT_FALSE(server.handshakeError_);
1536   EXPECT_LE(0, server.handshakeTime.count());
1537 }
1538
1539 /**
1540  * Test requireClientCert with no client cert
1541  */
1542 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1543   EventBase eventBase;
1544   auto clientCtx = std::make_shared<SSLContext>();
1545   auto serverCtx = std::make_shared<SSLContext>();
1546   serverCtx->setVerificationOption(
1547       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1548   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1549   serverCtx->loadPrivateKey(kTestKey);
1550   serverCtx->loadCertificate(kTestCert);
1551   serverCtx->loadTrustedCertificates(kTestCA);
1552   serverCtx->loadClientCAList(kTestCA);
1553   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1554   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1555
1556   int fds[2];
1557   getfds(fds);
1558
1559   AsyncSSLSocket::UniquePtr clientSock(
1560       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1561   AsyncSSLSocket::UniquePtr serverSock(
1562       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1563
1564   SSLHandshakeClient client(std::move(clientSock), false, false);
1565   SSLHandshakeServer server(std::move(serverSock), false, false);
1566
1567   eventBase.loop();
1568
1569   EXPECT_FALSE(server.handshakeVerify_);
1570   EXPECT_FALSE(server.handshakeSuccess_);
1571   EXPECT_TRUE(server.handshakeError_);
1572   EXPECT_LE(0, client.handshakeTime.count());
1573   EXPECT_LE(0, server.handshakeTime.count());
1574 }
1575
1576 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1577   auto cert = getFileAsBuf(kTestCert);
1578   auto key = getFileAsBuf(kTestKey);
1579
1580   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1581   BIO_write(certBio.get(), cert.data(), cert.size());
1582   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1583   BIO_write(keyBio.get(), key.data(), key.size());
1584
1585   // Create SSL structs from buffers to get properties
1586   ssl::X509UniquePtr certStruct(
1587       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1588   ssl::EvpPkeyUniquePtr keyStruct(
1589       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1590   certBio = nullptr;
1591   keyBio = nullptr;
1592
1593   auto origCommonName = getCommonName(certStruct.get());
1594   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1595   certStruct = nullptr;
1596   keyStruct = nullptr;
1597
1598   auto ctx = std::make_shared<SSLContext>();
1599   ctx->loadPrivateKeyFromBufferPEM(key);
1600   ctx->loadCertificateFromBufferPEM(cert);
1601   ctx->loadTrustedCertificates(kTestCA);
1602
1603   ssl::SSLUniquePtr ssl(ctx->createSSL());
1604
1605   auto newCert = SSL_get_certificate(ssl.get());
1606   auto newKey = SSL_get_privatekey(ssl.get());
1607
1608   // Get properties from SSL struct
1609   auto newCommonName = getCommonName(newCert);
1610   auto newKeySize = EVP_PKEY_bits(newKey);
1611
1612   // Check that the key and cert have the expected properties
1613   EXPECT_EQ(origCommonName, newCommonName);
1614   EXPECT_EQ(origKeySize, newKeySize);
1615 }
1616
1617 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1618   EventBase eb;
1619
1620   // Set up SSL context.
1621   auto sslContext = std::make_shared<SSLContext>();
1622   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1623
1624   // create SSL socket
1625   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1626
1627   EXPECT_EQ(1500, socket->getMinWriteSize());
1628
1629   socket->setMinWriteSize(0);
1630   EXPECT_EQ(0, socket->getMinWriteSize());
1631   socket->setMinWriteSize(50000);
1632   EXPECT_EQ(50000, socket->getMinWriteSize());
1633 }
1634
1635 class ReadCallbackTerminator : public ReadCallback {
1636  public:
1637   ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
1638       : ReadCallback(wcb), base_(base) {}
1639
1640   // Do not write data back, terminate the loop.
1641   void readDataAvailable(size_t len) noexcept override {
1642     std::cerr << "readDataAvailable, len " << len << std::endl;
1643
1644     currentBuffer.length = len;
1645
1646     buffers.push_back(currentBuffer);
1647     currentBuffer.reset();
1648     state = STATE_SUCCEEDED;
1649
1650     socket_->setReadCB(nullptr);
1651     base_->terminateLoopSoon();
1652   }
1653
1654  private:
1655   EventBase* base_;
1656 };
1657
1658 /**
1659  * Test a full unencrypted codepath
1660  */
1661 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1662   EventBase base;
1663
1664   auto clientCtx = std::make_shared<folly::SSLContext>();
1665   auto serverCtx = std::make_shared<folly::SSLContext>();
1666   int fds[2];
1667   getfds(fds);
1668   getctx(clientCtx, serverCtx);
1669   auto client =
1670       AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
1671   auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
1672
1673   ReadCallbackTerminator readCallback(&base, nullptr);
1674   server->setReadCB(&readCallback);
1675   readCallback.setSocket(server);
1676
1677   uint8_t buf[128];
1678   memset(buf, 'a', sizeof(buf));
1679   client->write(nullptr, buf, sizeof(buf));
1680
1681   // Check that bytes are unencrypted
1682   char c;
1683   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1684   EXPECT_EQ('a', c);
1685
1686   EventBaseAborter eba(&base, 3000);
1687   base.loop();
1688
1689   EXPECT_EQ(1, readCallback.buffers.size());
1690   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1691
1692   server->setReadCB(&readCallback);
1693
1694   // Unencrypted
1695   server->sslAccept(nullptr);
1696   client->sslConn(nullptr);
1697
1698   // Do NOT wait for handshake, writing should be queued and happen after
1699
1700   client->write(nullptr, buf, sizeof(buf));
1701
1702   // Check that bytes are *not* unencrypted
1703   char c2;
1704   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1705   EXPECT_NE('a', c2);
1706
1707   base.loop();
1708
1709   EXPECT_EQ(2, readCallback.buffers.size());
1710   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1711 }
1712
1713 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1714   auto clientCtx = std::make_shared<folly::SSLContext>();
1715   auto serverCtx = std::make_shared<folly::SSLContext>();
1716   getctx(clientCtx, serverCtx);
1717
1718   WriteCallbackBase writeCallback;
1719   ReadCallback readCallback(&writeCallback);
1720   HandshakeCallback handshakeCallback(&readCallback);
1721   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1722   TestSSLServer server(&acceptCallback);
1723
1724   EventBase evb;
1725   std::shared_ptr<AsyncSSLSocket> socket =
1726       AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1727   socket->connect(nullptr, server.getAddress(), 0);
1728
1729   evb.loop();
1730
1731   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1732   socket->sslConn(nullptr);
1733   evb.loop();
1734   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1735
1736   // write()
1737   std::array<uint8_t, 128> buf;
1738   memset(buf.data(), 'a', buf.size());
1739   socket->write(nullptr, buf.data(), buf.size());
1740
1741   socket->close();
1742 }
1743
1744 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1745   // Start listening on a local port
1746   WriteCallbackBase writeCallback;
1747   WriteErrorCallback readCallback(&writeCallback);
1748   HandshakeCallback handshakeCallback(
1749       &readCallback, HandshakeCallback::EXPECT_ERROR);
1750   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1751   TestSSLServer server(&acceptCallback);
1752
1753   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1754   socket->open();
1755   uint8_t buf[3] = {0x16, 0x03, 0x01};
1756   socket->write(buf, sizeof(buf));
1757   socket->closeWithReset();
1758
1759   handshakeCallback.waitForHandshake();
1760   EXPECT_NE(
1761       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1762   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1763 }
1764
1765 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1766   // Start listening on a local port
1767   WriteCallbackBase writeCallback;
1768   WriteErrorCallback readCallback(&writeCallback);
1769   HandshakeCallback handshakeCallback(
1770       &readCallback, HandshakeCallback::EXPECT_ERROR);
1771   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1772   TestSSLServer server(&acceptCallback);
1773
1774   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1775   socket->open();
1776   uint8_t buf[3] = {0x16, 0x03, 0x01};
1777   socket->write(buf, sizeof(buf));
1778   socket->close();
1779
1780   handshakeCallback.waitForHandshake();
1781 #if FOLLY_OPENSSL_IS_110
1782   EXPECT_NE(
1783       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1784 #else
1785   EXPECT_NE(
1786       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1787 #endif
1788 }
1789
1790 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1791   // Start listening on a local port
1792   WriteCallbackBase writeCallback;
1793   WriteErrorCallback readCallback(&writeCallback);
1794   HandshakeCallback handshakeCallback(
1795       &readCallback, HandshakeCallback::EXPECT_ERROR);
1796   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1797   TestSSLServer server(&acceptCallback);
1798
1799   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1800   socket->open();
1801   uint8_t buf[256] = {0x16, 0x03};
1802   memset(buf + 2, 'a', sizeof(buf) - 2);
1803   socket->write(buf, sizeof(buf));
1804   socket->close();
1805
1806   handshakeCallback.waitForHandshake();
1807   EXPECT_NE(
1808       handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
1809 #if defined(OPENSSL_IS_BORINGSSL)
1810   EXPECT_NE(
1811       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1812       std::string::npos);
1813 #elif FOLLY_OPENSSL_IS_110
1814   EXPECT_NE(
1815       handshakeCallback.errorString_.find("packet length too long"),
1816       std::string::npos);
1817 #else
1818   EXPECT_NE(
1819       handshakeCallback.errorString_.find("unknown protocol"),
1820       std::string::npos);
1821 #endif
1822 }
1823
1824 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1825   using folly::ssl::OpenSSLUtils;
1826   EXPECT_EQ(
1827       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1828   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1829   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1830   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1831   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1832 }
1833
1834 #if FOLLY_ALLOW_TFO
1835
1836 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1837  public:
1838   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1839
1840   explicit MockAsyncTFOSSLSocket(
1841       std::shared_ptr<folly::SSLContext> sslCtx,
1842       EventBase* evb)
1843       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1844
1845   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1846 };
1847
1848 /**
1849  * Test connecting to, writing to, reading from, and closing the
1850  * connection to the SSL server with TFO.
1851  */
1852 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1853   // Start listening on a local port
1854   WriteCallbackBase writeCallback;
1855   ReadCallback readCallback(&writeCallback);
1856   HandshakeCallback handshakeCallback(&readCallback);
1857   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1858   TestSSLServer server(&acceptCallback, true);
1859
1860   // Set up SSL context.
1861   auto sslContext = std::make_shared<SSLContext>();
1862
1863   // connect
1864   auto socket =
1865       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1866   socket->enableTFO();
1867   socket->open();
1868
1869   // write()
1870   std::array<uint8_t, 128> buf;
1871   memset(buf.data(), 'a', buf.size());
1872   socket->write(buf.data(), buf.size());
1873
1874   // read()
1875   std::array<uint8_t, 128> readbuf;
1876   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1877   EXPECT_EQ(bytesRead, 128);
1878   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1879
1880   // close()
1881   socket->close();
1882 }
1883
1884 /**
1885  * Test connecting to, writing to, reading from, and closing the
1886  * connection to the SSL server with TFO.
1887  */
1888 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1889   // Start listening on a local port
1890   WriteCallbackBase writeCallback;
1891   ReadCallback readCallback(&writeCallback);
1892   HandshakeCallback handshakeCallback(&readCallback);
1893   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1894   TestSSLServer server(&acceptCallback, false);
1895
1896   // Set up SSL context.
1897   auto sslContext = std::make_shared<SSLContext>();
1898
1899   // connect
1900   auto socket =
1901       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1902   socket->enableTFO();
1903   socket->open();
1904
1905   // write()
1906   std::array<uint8_t, 128> buf;
1907   memset(buf.data(), 'a', buf.size());
1908   socket->write(buf.data(), buf.size());
1909
1910   // read()
1911   std::array<uint8_t, 128> readbuf;
1912   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1913   EXPECT_EQ(bytesRead, 128);
1914   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1915
1916   // close()
1917   socket->close();
1918 }
1919
1920 class ConnCallback : public AsyncSocket::ConnectCallback {
1921  public:
1922   void connectSuccess() noexcept override {
1923     state = State::SUCCESS;
1924   }
1925
1926   void connectErr(const AsyncSocketException& ex) noexcept override {
1927     state = State::ERROR;
1928     error = ex.what();
1929   }
1930
1931   enum class State { WAITING, SUCCESS, ERROR };
1932
1933   State state{State::WAITING};
1934   std::string error;
1935 };
1936
1937 template <class Cardinality>
1938 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1939     EventBase* evb,
1940     const SocketAddress& address,
1941     Cardinality cardinality) {
1942   // Set up SSL context.
1943   auto sslContext = std::make_shared<SSLContext>();
1944
1945   // connect
1946   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1947       new MockAsyncTFOSSLSocket(sslContext, evb));
1948   socket->enableTFO();
1949
1950   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1951       .Times(cardinality)
1952       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1953         sockaddr_storage addr;
1954         auto len = address.getAddress(&addr);
1955         return connect(fd, (const struct sockaddr*)&addr, len);
1956       }));
1957   return socket;
1958 }
1959
1960 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1961   // Start listening on a local port
1962   WriteCallbackBase writeCallback;
1963   ReadCallback readCallback(&writeCallback);
1964   HandshakeCallback handshakeCallback(&readCallback);
1965   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1966   TestSSLServer server(&acceptCallback, true);
1967
1968   EventBase evb;
1969
1970   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1971   ConnCallback ccb;
1972   socket->connect(&ccb, server.getAddress(), 30);
1973
1974   evb.loop();
1975   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1976
1977   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1978   evb.loop();
1979
1980   BlockingSocket sock(std::move(socket));
1981   // write()
1982   std::array<uint8_t, 128> buf;
1983   memset(buf.data(), 'a', buf.size());
1984   sock.write(buf.data(), buf.size());
1985
1986   // read()
1987   std::array<uint8_t, 128> readbuf;
1988   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1989   EXPECT_EQ(bytesRead, 128);
1990   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1991
1992   // close()
1993   sock.close();
1994 }
1995
1996 #if !defined(OPENSSL_IS_BORINGSSL)
1997 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1998   // Start listening on a local port
1999   ConnectTimeoutCallback acceptCallback;
2000   TestSSLServer server(&acceptCallback, true);
2001
2002   // Set up SSL context.
2003   auto sslContext = std::make_shared<SSLContext>();
2004
2005   // connect
2006   auto socket =
2007       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2008   socket->enableTFO();
2009   EXPECT_THROW(
2010       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
2011 }
2012 #endif
2013
2014 #if !defined(OPENSSL_IS_BORINGSSL)
2015 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
2016   // Start listening on a local port
2017   ConnectTimeoutCallback acceptCallback;
2018   TestSSLServer server(&acceptCallback, true);
2019
2020   EventBase evb;
2021
2022   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2023   ConnCallback ccb;
2024   // Set a short timeout
2025   socket->connect(&ccb, server.getAddress(), 1);
2026
2027   evb.loop();
2028   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2029 }
2030 #endif
2031
2032 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
2033   // Start listening on a local port
2034   EmptyReadCallback readCallback;
2035   HandshakeCallback handshakeCallback(
2036       &readCallback, HandshakeCallback::EXPECT_ERROR);
2037   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
2038   TestSSLServer server(&acceptCallback, true);
2039
2040   EventBase evb;
2041
2042   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
2043   ConnCallback ccb;
2044   socket->connect(&ccb, server.getAddress(), 100);
2045
2046   evb.loop();
2047   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2048   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
2049 }
2050
2051 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
2052   // Start listening on a local port
2053   EventBase evb;
2054
2055   // Hopefully nothing is listening on this address
2056   SocketAddress addr("127.0.0.1", 65535);
2057   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
2058   ConnCallback ccb;
2059   socket->connect(&ccb, addr, 100);
2060
2061   evb.loop();
2062   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
2063   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
2064 }
2065
2066 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
2067   EventBase clientEventBase;
2068   EventBase serverEventBase;
2069   auto clientCtx = std::make_shared<SSLContext>();
2070   auto dfServerCtx = std::make_shared<SSLContext>();
2071   std::array<int, 2> fds;
2072   getfds(fds.data());
2073   getctx(clientCtx, dfServerCtx);
2074
2075   AsyncSSLSocket::UniquePtr clientSockPtr(
2076       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2077   AsyncSSLSocket::UniquePtr serverSockPtr(
2078       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
2079   auto clientSock = clientSockPtr.get();
2080   auto serverSock = serverSockPtr.get();
2081   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2082
2083   // Steal some data from the server.
2084   clientEventBase.loopOnce();
2085   std::array<uint8_t, 10> buf;
2086   recv(fds[1], buf.data(), buf.size(), 0);
2087
2088   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2089   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2090   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2091     serverEventBase.loopOnce();
2092     clientEventBase.loopOnce();
2093   }
2094
2095   EXPECT_TRUE(client.handshakeSuccess_);
2096   EXPECT_TRUE(server.handshakeSuccess_);
2097   EXPECT_EQ(
2098       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2099 }
2100
2101 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2102   EventBase clientEventBase;
2103   EventBase serverEventBase;
2104   auto clientCtx = std::make_shared<SSLContext>();
2105   auto dfServerCtx = std::make_shared<SSLContext>();
2106   std::array<int, 2> fds;
2107   getfds(fds.data());
2108   getctx(clientCtx, dfServerCtx);
2109
2110   AsyncSSLSocket::UniquePtr clientSockPtr(
2111       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2112   AsyncSocket::UniquePtr serverSockPtr(
2113       new AsyncSocket(&serverEventBase, fds[1]));
2114   auto clientSock = clientSockPtr.get();
2115   auto serverSock = serverSockPtr.get();
2116   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2117
2118   // Steal some data from the server.
2119   clientEventBase.loopOnce();
2120   std::array<uint8_t, 10> buf;
2121   recv(fds[1], buf.data(), buf.size(), 0);
2122   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2123   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2124       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2125   auto serverSSLSock = serverSSLSockPtr.get();
2126   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2127   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2128     serverEventBase.loopOnce();
2129     clientEventBase.loopOnce();
2130   }
2131
2132   EXPECT_TRUE(client.handshakeSuccess_);
2133   EXPECT_TRUE(server.handshakeSuccess_);
2134   EXPECT_EQ(
2135       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2136 }
2137
2138 /**
2139  * Test overriding the flags passed to "sendmsg()" system call,
2140  * and verifying that write requests fail properly.
2141  */
2142 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2143   // Start listening on a local port
2144   SendMsgFlagsCallback msgCallback;
2145   ExpectWriteErrorCallback writeCallback(&msgCallback);
2146   ReadCallback readCallback(&writeCallback);
2147   HandshakeCallback handshakeCallback(&readCallback);
2148   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2149   TestSSLServer server(&acceptCallback);
2150
2151   // Set up SSL context.
2152   auto sslContext = std::make_shared<SSLContext>();
2153   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2154
2155   // connect
2156   auto socket =
2157       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2158   socket->open();
2159
2160   // Setting flags to "-1" to trigger "Invalid argument" error
2161   // on attempt to use this flags in sendmsg() system call.
2162   msgCallback.resetFlags(-1);
2163
2164   // write()
2165   std::vector<uint8_t> buf(128, 'a');
2166   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2167
2168   // close()
2169   socket->close();
2170
2171   cerr << "SendMsgParamsCallback test completed" << endl;
2172 }
2173
2174 #ifdef MSG_ERRQUEUE
2175 /**
2176  * Test connecting to, writing to, reading from, and closing the
2177  * connection to the SSL server.
2178  */
2179 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2180   // This test requires Linux kernel v4.6 or later
2181   struct utsname s_uname;
2182   memset(&s_uname, 0, sizeof(s_uname));
2183   ASSERT_EQ(uname(&s_uname), 0);
2184   int major, minor;
2185   folly::StringPiece extra;
2186   if (folly::split<false>(
2187           '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2188     if (major < 4 || (major == 4 && minor < 6)) {
2189       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2190                 << "kernel ver. " << s_uname.release << " detected).";
2191       return;
2192     }
2193   }
2194
2195   // Start listening on a local port
2196   SendMsgDataCallback msgCallback;
2197   WriteCheckTimestampCallback writeCallback(&msgCallback);
2198   ReadCallback readCallback(&writeCallback);
2199   HandshakeCallback handshakeCallback(&readCallback);
2200   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2201   TestSSLServer server(&acceptCallback);
2202
2203   // Set up SSL context.
2204   auto sslContext = std::make_shared<SSLContext>();
2205   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2206
2207   // connect
2208   auto socket =
2209       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
2210   socket->open();
2211
2212   // Adding MSG_EOR flag to the message flags - it'll trigger
2213   // timestamp generation for the last byte of the message.
2214   msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
2215
2216   // Init ancillary data buffer to trigger timestamp notification
2217   union {
2218     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2219     struct cmsghdr cmsg;
2220   } u;
2221   u.cmsg.cmsg_level = SOL_SOCKET;
2222   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2223   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2224   uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
2225       SOF_TIMESTAMPING_TX_ACK;
2226   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2227   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2228   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2229   msgCallback.resetData(std::move(ctrl));
2230
2231   // write()
2232   std::vector<uint8_t> buf(128, 'a');
2233   socket->write(buf.data(), buf.size());
2234
2235   // read()
2236   std::vector<uint8_t> readbuf(buf.size());
2237   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2238   EXPECT_EQ(bytesRead, buf.size());
2239   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2240
2241   writeCallback.checkForTimestampNotifications();
2242
2243   // close()
2244   socket->close();
2245
2246   cerr << "SendMsgDataCallback test completed" << endl;
2247 }
2248 #endif // MSG_ERRQUEUE
2249
2250 #endif
2251
2252 } // namespace
2253
2254 #ifdef SIGPIPE
2255 ///////////////////////////////////////////////////////////////////////////
2256 // init_unit_test_suite
2257 ///////////////////////////////////////////////////////////////////////////
2258 namespace {
2259 struct Initializer {
2260   Initializer() {
2261     signal(SIGPIPE, SIG_IGN);
2262   }
2263 };
2264 Initializer initializer;
2265 } // namespace
2266 #endif