AsyncSSLSocket connect without SSL
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33 #include <sys/utsname.h>
34
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 void getfds(int fds[2]) {
59   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
60     FAIL() << "failed to create socketpair: " << strerror(errno);
61   }
62   for (int idx = 0; idx < 2; ++idx) {
63     int flags = fcntl(fds[idx], F_GETFL, 0);
64     if (flags == -1) {
65       FAIL() << "failed to get flags for socket " << idx << ": "
66              << strerror(errno);
67     }
68     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
69       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
70              << strerror(errno);
71     }
72   }
73 }
74
75 void getctx(
76   std::shared_ptr<folly::SSLContext> clientCtx,
77   std::shared_ptr<folly::SSLContext> serverCtx) {
78   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
79
80   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
81   serverCtx->loadCertificate(kTestCert);
82   serverCtx->loadPrivateKey(kTestKey);
83 }
84
85 void sslsocketpair(
86   EventBase* eventBase,
87   AsyncSSLSocket::UniquePtr* clientSock,
88   AsyncSSLSocket::UniquePtr* serverSock) {
89   auto clientCtx = std::make_shared<folly::SSLContext>();
90   auto serverCtx = std::make_shared<folly::SSLContext>();
91   int fds[2];
92   getfds(fds);
93   getctx(clientCtx, serverCtx);
94   clientSock->reset(new AsyncSSLSocket(
95                       clientCtx, eventBase, fds[0], false));
96   serverSock->reset(new AsyncSSLSocket(
97                       serverCtx, eventBase, fds[1], true));
98
99   // (*clientSock)->setSendTimeout(100);
100   // (*serverSock)->setSendTimeout(100);
101 }
102
103 // client protocol filters
104 bool clientProtoFilterPickPony(unsigned char** client,
105   unsigned int* client_len, const unsigned char*, unsigned int ) {
106   //the protocol string in length prefixed byte string. the
107   //length byte is not included in the length
108   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
109   *client = p;
110   *client_len = 7;
111   return true;
112 }
113
114 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
115   const unsigned char*, unsigned int) {
116   return false;
117 }
118
119 std::string getFileAsBuf(const char* fileName) {
120   std::string buffer;
121   folly::readFile(fileName, buffer);
122   return buffer;
123 }
124
125 std::string getCommonName(X509* cert) {
126   X509_NAME* subject = X509_get_subject_name(cert);
127   std::string cn;
128   cn.resize(ub_common_name);
129   X509_NAME_get_text_by_NID(
130       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
131   return cn;
132 }
133
134 /**
135  * Test connecting to, writing to, reading from, and closing the
136  * connection to the SSL server.
137  */
138 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
139   // Start listening on a local port
140   WriteCallbackBase writeCallback;
141   ReadCallback readCallback(&writeCallback);
142   HandshakeCallback handshakeCallback(&readCallback);
143   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
144   TestSSLServer server(&acceptCallback);
145
146   // Set up SSL context.
147   std::shared_ptr<SSLContext> sslContext(new SSLContext());
148   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
149   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
150   //sslContext->authenticate(true, false);
151
152   // connect
153   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
154                                                  sslContext);
155   socket->open(std::chrono::milliseconds(10000));
156
157   // write()
158   uint8_t buf[128];
159   memset(buf, 'a', sizeof(buf));
160   socket->write(buf, sizeof(buf));
161
162   // read()
163   uint8_t readbuf[128];
164   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
165   EXPECT_EQ(bytesRead, 128);
166   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
167
168   // close()
169   socket->close();
170
171   cerr << "ConnectWriteReadClose test completed" << endl;
172   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
173 }
174
175 /**
176  * Test reading after server close.
177  */
178 TEST(AsyncSSLSocketTest, ReadAfterClose) {
179   // Start listening on a local port
180   WriteCallbackBase writeCallback;
181   ReadEOFCallback readCallback(&writeCallback);
182   HandshakeCallback handshakeCallback(&readCallback);
183   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
184   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
185
186   // Set up SSL context.
187   auto sslContext = std::make_shared<SSLContext>();
188   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
189
190   auto socket =
191       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
192   socket->open();
193
194   // This should trigger an EOF on the client.
195   auto evb = handshakeCallback.getSocket()->getEventBase();
196   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
197   std::array<uint8_t, 128> readbuf;
198   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
199   EXPECT_EQ(0, bytesRead);
200 }
201
202 /**
203  * Test bad renegotiation
204  */
205 #if !defined(OPENSSL_IS_BORINGSSL)
206 TEST(AsyncSSLSocketTest, Renegotiate) {
207   EventBase eventBase;
208   auto clientCtx = std::make_shared<SSLContext>();
209   auto dfServerCtx = std::make_shared<SSLContext>();
210   std::array<int, 2> fds;
211   getfds(fds.data());
212   getctx(clientCtx, dfServerCtx);
213
214   AsyncSSLSocket::UniquePtr clientSock(
215       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
216   AsyncSSLSocket::UniquePtr serverSock(
217       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
218   SSLHandshakeClient client(std::move(clientSock), true, true);
219   RenegotiatingServer server(std::move(serverSock));
220
221   while (!client.handshakeSuccess_ && !client.handshakeError_) {
222     eventBase.loopOnce();
223   }
224
225   ASSERT_TRUE(client.handshakeSuccess_);
226
227   auto sslSock = std::move(client).moveSocket();
228   sslSock->detachEventBase();
229   // This is nasty, however we don't want to add support for
230   // renegotiation in AsyncSSLSocket.
231   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
232
233   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
234
235   std::thread t([&]() { eventBase.loopForever(); });
236
237   // Trigger the renegotiation.
238   std::array<uint8_t, 128> buf;
239   memset(buf.data(), 'a', buf.size());
240   try {
241     socket->write(buf.data(), buf.size());
242   } catch (AsyncSocketException& e) {
243     LOG(INFO) << "client got error " << e.what();
244   }
245   eventBase.terminateLoopSoon();
246   t.join();
247
248   eventBase.loop();
249   ASSERT_TRUE(server.renegotiationError_);
250 }
251 #endif
252
253 /**
254  * Negative test for handshakeError().
255  */
256 TEST(AsyncSSLSocketTest, HandshakeError) {
257   // Start listening on a local port
258   WriteCallbackBase writeCallback;
259   WriteErrorCallback readCallback(&writeCallback);
260   HandshakeCallback handshakeCallback(&readCallback);
261   HandshakeErrorCallback acceptCallback(&handshakeCallback);
262   TestSSLServer server(&acceptCallback);
263
264   // Set up SSL context.
265   std::shared_ptr<SSLContext> sslContext(new SSLContext());
266   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
267
268   // connect
269   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
270                                                  sslContext);
271   // read()
272   bool ex = false;
273   try {
274     socket->open();
275
276     uint8_t readbuf[128];
277     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
278     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
279   } catch (AsyncSocketException&) {
280     ex = true;
281   }
282   EXPECT_TRUE(ex);
283
284   // close()
285   socket->close();
286   cerr << "HandshakeError test completed" << endl;
287 }
288
289 /**
290  * Negative test for readError().
291  */
292 TEST(AsyncSSLSocketTest, ReadError) {
293   // Start listening on a local port
294   WriteCallbackBase writeCallback;
295   ReadErrorCallback readCallback(&writeCallback);
296   HandshakeCallback handshakeCallback(&readCallback);
297   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
298   TestSSLServer server(&acceptCallback);
299
300   // Set up SSL context.
301   std::shared_ptr<SSLContext> sslContext(new SSLContext());
302   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
303
304   // connect
305   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
306                                                  sslContext);
307   socket->open();
308
309   // write something to trigger ssl handshake
310   uint8_t buf[128];
311   memset(buf, 'a', sizeof(buf));
312   socket->write(buf, sizeof(buf));
313
314   socket->close();
315   cerr << "ReadError test completed" << endl;
316 }
317
318 /**
319  * Negative test for writeError().
320  */
321 TEST(AsyncSSLSocketTest, WriteError) {
322   // Start listening on a local port
323   WriteCallbackBase writeCallback;
324   WriteErrorCallback readCallback(&writeCallback);
325   HandshakeCallback handshakeCallback(&readCallback);
326   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
327   TestSSLServer server(&acceptCallback);
328
329   // Set up SSL context.
330   std::shared_ptr<SSLContext> sslContext(new SSLContext());
331   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
332
333   // connect
334   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
335                                                  sslContext);
336   socket->open();
337
338   // write something to trigger ssl handshake
339   uint8_t buf[128];
340   memset(buf, 'a', sizeof(buf));
341   socket->write(buf, sizeof(buf));
342
343   socket->close();
344   cerr << "WriteError test completed" << endl;
345 }
346
347 /**
348  * Test a socket with TCP_NODELAY unset.
349  */
350 TEST(AsyncSSLSocketTest, SocketWithDelay) {
351   // Start listening on a local port
352   WriteCallbackBase writeCallback;
353   ReadCallback readCallback(&writeCallback);
354   HandshakeCallback handshakeCallback(&readCallback);
355   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
356   TestSSLServer server(&acceptCallback);
357
358   // Set up SSL context.
359   std::shared_ptr<SSLContext> sslContext(new SSLContext());
360   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
361
362   // connect
363   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
364                                                  sslContext);
365   socket->open();
366
367   // write()
368   uint8_t buf[128];
369   memset(buf, 'a', sizeof(buf));
370   socket->write(buf, sizeof(buf));
371
372   // read()
373   uint8_t readbuf[128];
374   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
375   EXPECT_EQ(bytesRead, 128);
376   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
377
378   // close()
379   socket->close();
380
381   cerr << "SocketWithDelay test completed" << endl;
382 }
383
384 using NextProtocolTypePair =
385     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
386
387 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
388   // For matching protos
389  public:
390   void SetUp() override { getctx(clientCtx, serverCtx); }
391
392   void connect(bool unset = false) {
393     getfds(fds);
394
395     if (unset) {
396       // unsetting NPN for any of [client, server] is enough to make NPN not
397       // work
398       clientCtx->unsetNextProtocols();
399     }
400
401     AsyncSSLSocket::UniquePtr clientSock(
402       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
403     AsyncSSLSocket::UniquePtr serverSock(
404       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
405     client = std::make_unique<NpnClient>(std::move(clientSock));
406     server = std::make_unique<NpnServer>(std::move(serverSock));
407
408     eventBase.loop();
409   }
410
411   void expectProtocol(const std::string& proto) {
412     expectHandshakeSuccess();
413     EXPECT_NE(client->nextProtoLength, 0);
414     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
415     EXPECT_EQ(
416         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
417         0);
418     string selected((const char*)client->nextProto, client->nextProtoLength);
419     EXPECT_EQ(proto, selected);
420   }
421
422   void expectNoProtocol() {
423     expectHandshakeSuccess();
424     EXPECT_EQ(client->nextProtoLength, 0);
425     EXPECT_EQ(server->nextProtoLength, 0);
426     EXPECT_EQ(client->nextProto, nullptr);
427     EXPECT_EQ(server->nextProto, nullptr);
428   }
429
430   void expectProtocolType() {
431     expectHandshakeSuccess();
432     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
433         GetParam().second == SSLContext::NextProtocolType::ANY) {
434       EXPECT_EQ(client->protocolType, server->protocolType);
435     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
436                GetParam().second == SSLContext::NextProtocolType::ANY) {
437       // Well not much we can say
438     } else {
439       expectProtocolType(GetParam());
440     }
441   }
442
443   void expectProtocolType(NextProtocolTypePair expected) {
444     expectHandshakeSuccess();
445     EXPECT_EQ(client->protocolType, expected.first);
446     EXPECT_EQ(server->protocolType, expected.second);
447   }
448
449   void expectHandshakeSuccess() {
450     EXPECT_FALSE(client->except.hasValue())
451         << "client handshake error: " << client->except->what();
452     EXPECT_FALSE(server->except.hasValue())
453         << "server handshake error: " << server->except->what();
454   }
455
456   void expectHandshakeError() {
457     EXPECT_TRUE(client->except.hasValue())
458         << "Expected client handshake error!";
459     EXPECT_TRUE(server->except.hasValue())
460         << "Expected server handshake error!";
461   }
462
463   EventBase eventBase;
464   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
465   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
466   int fds[2];
467   std::unique_ptr<NpnClient> client;
468   std::unique_ptr<NpnServer> server;
469 };
470
471 class NextProtocolTLSExtTest : public NextProtocolTest {
472   // For extended TLS protos
473 };
474
475 class NextProtocolNPNOnlyTest : public NextProtocolTest {
476   // For mismatching protos
477 };
478
479 class NextProtocolMismatchTest : public NextProtocolTest {
480   // For mismatching protos
481 };
482
483 TEST_P(NextProtocolTest, NpnTestOverlap) {
484   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
485   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
486                                         GetParam().second);
487
488   connect();
489
490   expectProtocol("baz");
491   expectProtocolType();
492 }
493
494 TEST_P(NextProtocolTest, NpnTestUnset) {
495   // Identical to above test, except that we want unset NPN before
496   // looping.
497   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
498   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
499                                         GetParam().second);
500
501   connect(true /* unset */);
502
503   // if alpn negotiation fails, type will appear as npn
504   expectNoProtocol();
505   EXPECT_EQ(client->protocolType, server->protocolType);
506 }
507
508 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
509   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
510   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
511                                         GetParam().second);
512
513   connect();
514
515   expectNoProtocol();
516   expectProtocolType(
517       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
518 }
519
520 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
521 // will fail on 1.0.2 before that.
522 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
523   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
524   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
525                                         GetParam().second);
526   connect();
527
528   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
529       GetParam().second == SSLContext::NextProtocolType::ALPN) {
530     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
531     // mismatch should result in a fatal alert, but this is the current behavior
532     // on all OpenSSL versions/variants, and we want to know if it changes.
533     expectNoProtocol();
534   }
535 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
536   else if (
537       GetParam().first == SSLContext::NextProtocolType::ANY &&
538       GetParam().second == SSLContext::NextProtocolType::ANY) {
539 # if FOLLY_OPENSSL_IS_110
540     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
541     // correct behavior per RFC7301
542     expectHandshakeError();
543 # else
544     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
545     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
546     // NPN *and* ALPN
547     expectNoProtocol();
548 # endif
549   }
550 #endif
551    else {
552     expectProtocol("blub");
553     expectProtocolType(
554         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
555   }
556 }
557
558 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
559   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
560   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
561   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
562                                         GetParam().second);
563
564   connect();
565
566   expectProtocol("ponies");
567   expectProtocolType();
568 }
569
570 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
571   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
572   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
573   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
574                                         GetParam().second);
575
576   connect();
577
578   expectProtocol("blub");
579   expectProtocolType();
580 }
581
582 TEST_P(NextProtocolTest, RandomizedNpnTest) {
583   // Probability that this test will fail is 2^-64, which could be considered
584   // as negligible.
585   const int kTries = 64;
586
587   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
588                                         GetParam().first);
589   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
590                                                   GetParam().second);
591
592   std::set<string> selectedProtocols;
593   for (int i = 0; i < kTries; ++i) {
594     connect();
595
596     EXPECT_NE(client->nextProtoLength, 0);
597     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
598     EXPECT_EQ(
599         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
600         0);
601     string selected((const char*)client->nextProto, client->nextProtoLength);
602     selectedProtocols.insert(selected);
603     expectProtocolType();
604   }
605   EXPECT_EQ(selectedProtocols.size(), 2);
606 }
607
608 INSTANTIATE_TEST_CASE_P(
609     AsyncSSLSocketTest,
610     NextProtocolTest,
611     ::testing::Values(
612         NextProtocolTypePair(
613             SSLContext::NextProtocolType::NPN,
614             SSLContext::NextProtocolType::NPN),
615         NextProtocolTypePair(
616             SSLContext::NextProtocolType::NPN,
617             SSLContext::NextProtocolType::ANY),
618         NextProtocolTypePair(
619             SSLContext::NextProtocolType::ANY,
620             SSLContext::NextProtocolType::ANY)));
621
622 #if FOLLY_OPENSSL_HAS_ALPN
623 INSTANTIATE_TEST_CASE_P(
624     AsyncSSLSocketTest,
625     NextProtocolTLSExtTest,
626     ::testing::Values(
627         NextProtocolTypePair(
628             SSLContext::NextProtocolType::ALPN,
629             SSLContext::NextProtocolType::ALPN),
630         NextProtocolTypePair(
631             SSLContext::NextProtocolType::ALPN,
632             SSLContext::NextProtocolType::ANY),
633         NextProtocolTypePair(
634             SSLContext::NextProtocolType::ANY,
635             SSLContext::NextProtocolType::ALPN)));
636 #endif
637
638 INSTANTIATE_TEST_CASE_P(
639     AsyncSSLSocketTest,
640     NextProtocolNPNOnlyTest,
641     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
642                                            SSLContext::NextProtocolType::NPN)));
643
644 #if FOLLY_OPENSSL_HAS_ALPN
645 INSTANTIATE_TEST_CASE_P(
646     AsyncSSLSocketTest,
647     NextProtocolMismatchTest,
648     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
649                                            SSLContext::NextProtocolType::ALPN),
650                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
651                                            SSLContext::NextProtocolType::NPN)));
652 #endif
653
654 #ifndef OPENSSL_NO_TLSEXT
655 /**
656  * 1. Client sends TLSEXT_HOSTNAME in client hello.
657  * 2. Server found a match SSL_CTX and use this SSL_CTX to
658  *    continue the SSL handshake.
659  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
660  */
661 TEST(AsyncSSLSocketTest, SNITestMatch) {
662   EventBase eventBase;
663   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
664   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
665   // Use the same SSLContext to continue the handshake after
666   // tlsext_hostname match.
667   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
668   const std::string serverName("xyz.newdev.facebook.com");
669   int fds[2];
670   getfds(fds);
671   getctx(clientCtx, dfServerCtx);
672
673   AsyncSSLSocket::UniquePtr clientSock(
674     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
675   AsyncSSLSocket::UniquePtr serverSock(
676     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
677   SNIClient client(std::move(clientSock));
678   SNIServer server(std::move(serverSock),
679                    dfServerCtx,
680                    hskServerCtx,
681                    serverName);
682
683   eventBase.loop();
684
685   EXPECT_TRUE(client.serverNameMatch);
686   EXPECT_TRUE(server.serverNameMatch);
687 }
688
689 /**
690  * 1. Client sends TLSEXT_HOSTNAME in client hello.
691  * 2. Server cannot find a matching SSL_CTX and continue to use
692  *    the current SSL_CTX to do the handshake.
693  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
694  */
695 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
696   EventBase eventBase;
697   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
698   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
699   // Use the same SSLContext to continue the handshake after
700   // tlsext_hostname match.
701   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
702   const std::string clientRequestingServerName("foo.com");
703   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
704
705   int fds[2];
706   getfds(fds);
707   getctx(clientCtx, dfServerCtx);
708
709   AsyncSSLSocket::UniquePtr clientSock(
710     new AsyncSSLSocket(clientCtx,
711                         &eventBase,
712                         fds[0],
713                         clientRequestingServerName));
714   AsyncSSLSocket::UniquePtr serverSock(
715     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
716   SNIClient client(std::move(clientSock));
717   SNIServer server(std::move(serverSock),
718                    dfServerCtx,
719                    hskServerCtx,
720                    serverExpectedServerName);
721
722   eventBase.loop();
723
724   EXPECT_TRUE(!client.serverNameMatch);
725   EXPECT_TRUE(!server.serverNameMatch);
726 }
727 /**
728  * 1. Client sends TLSEXT_HOSTNAME in client hello.
729  * 2. We then change the serverName.
730  * 3. We expect that we get 'false' as the result for serNameMatch.
731  */
732
733 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
734    EventBase eventBase;
735   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
736   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
737   // Use the same SSLContext to continue the handshake after
738   // tlsext_hostname match.
739   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
740   const std::string serverName("xyz.newdev.facebook.com");
741   int fds[2];
742   getfds(fds);
743   getctx(clientCtx, dfServerCtx);
744
745   AsyncSSLSocket::UniquePtr clientSock(
746     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
747   //Change the server name
748   std::string newName("new.com");
749   clientSock->setServerName(newName);
750   AsyncSSLSocket::UniquePtr serverSock(
751     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
752   SNIClient client(std::move(clientSock));
753   SNIServer server(std::move(serverSock),
754                    dfServerCtx,
755                    hskServerCtx,
756                    serverName);
757
758   eventBase.loop();
759
760   EXPECT_TRUE(!client.serverNameMatch);
761 }
762
763 /**
764  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
765  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
766  */
767 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
768   EventBase eventBase;
769   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
770   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
771   // Use the same SSLContext to continue the handshake after
772   // tlsext_hostname match.
773   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
774   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
775
776   int fds[2];
777   getfds(fds);
778   getctx(clientCtx, dfServerCtx);
779
780   AsyncSSLSocket::UniquePtr clientSock(
781     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
782   AsyncSSLSocket::UniquePtr serverSock(
783     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
784   SNIClient client(std::move(clientSock));
785   SNIServer server(std::move(serverSock),
786                    dfServerCtx,
787                    hskServerCtx,
788                    serverExpectedServerName);
789
790   eventBase.loop();
791
792   EXPECT_TRUE(!client.serverNameMatch);
793   EXPECT_TRUE(!server.serverNameMatch);
794 }
795
796 #endif
797 /**
798  * Test SSL client socket
799  */
800 TEST(AsyncSSLSocketTest, SSLClientTest) {
801   // Start listening on a local port
802   WriteCallbackBase writeCallback;
803   ReadCallback readCallback(&writeCallback);
804   HandshakeCallback handshakeCallback(&readCallback);
805   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
806   TestSSLServer server(&acceptCallback);
807
808   // Set up SSL client
809   EventBase eventBase;
810   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
811
812   client->connect();
813   EventBaseAborter eba(&eventBase, 3000);
814   eventBase.loop();
815
816   EXPECT_EQ(client->getMiss(), 1);
817   EXPECT_EQ(client->getHit(), 0);
818
819   cerr << "SSLClientTest test completed" << endl;
820 }
821
822
823 /**
824  * Test SSL client socket session re-use
825  */
826 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
827   // Start listening on a local port
828   WriteCallbackBase writeCallback;
829   ReadCallback readCallback(&writeCallback);
830   HandshakeCallback handshakeCallback(&readCallback);
831   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
832   TestSSLServer server(&acceptCallback);
833
834   // Set up SSL client
835   EventBase eventBase;
836   auto client =
837       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
838
839   client->connect();
840   EventBaseAborter eba(&eventBase, 3000);
841   eventBase.loop();
842
843   EXPECT_EQ(client->getMiss(), 1);
844   EXPECT_EQ(client->getHit(), 9);
845
846   cerr << "SSLClientTestReuse test completed" << endl;
847 }
848
849 /**
850  * Test SSL client socket timeout
851  */
852 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
853   // Start listening on a local port
854   EmptyReadCallback readCallback;
855   HandshakeCallback handshakeCallback(&readCallback,
856                                       HandshakeCallback::EXPECT_ERROR);
857   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
858   TestSSLServer server(&acceptCallback);
859
860   // Set up SSL client
861   EventBase eventBase;
862   auto client =
863       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
864   client->connect(true /* write before connect completes */);
865   EventBaseAborter eba(&eventBase, 3000);
866   eventBase.loop();
867
868   usleep(100000);
869   // This is checking that the connectError callback precedes any queued
870   // writeError callbacks.  This matches AsyncSocket's behavior
871   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
872   EXPECT_EQ(client->getErrors(), 1);
873   EXPECT_EQ(client->getMiss(), 0);
874   EXPECT_EQ(client->getHit(), 0);
875
876   cerr << "SSLClientTimeoutTest test completed" << endl;
877 }
878
879 // The next 3 tests need an FB-only extension, and will fail without it
880 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
881 /**
882  * Test SSL server async cache
883  */
884 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
885   // Start listening on a local port
886   WriteCallbackBase writeCallback;
887   ReadCallback readCallback(&writeCallback);
888   HandshakeCallback handshakeCallback(&readCallback);
889   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
890   TestSSLAsyncCacheServer server(&acceptCallback);
891
892   // Set up SSL client
893   EventBase eventBase;
894   auto client =
895       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
896
897   client->connect();
898   EventBaseAborter eba(&eventBase, 3000);
899   eventBase.loop();
900
901   EXPECT_EQ(server.getAsyncCallbacks(), 18);
902   EXPECT_EQ(server.getAsyncLookups(), 9);
903   EXPECT_EQ(client->getMiss(), 10);
904   EXPECT_EQ(client->getHit(), 0);
905
906   cerr << "SSLServerAsyncCacheTest test completed" << endl;
907 }
908
909 /**
910  * Test SSL server accept timeout with cache path
911  */
912 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
913   // Start listening on a local port
914   WriteCallbackBase writeCallback;
915   ReadCallback readCallback(&writeCallback);
916   HandshakeCallback handshakeCallback(&readCallback);
917   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
918   TestSSLAsyncCacheServer server(&acceptCallback);
919
920   // Set up SSL client
921   EventBase eventBase;
922   // only do a TCP connect
923   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
924   sock->connect(nullptr, server.getAddress());
925
926   EmptyReadCallback clientReadCallback;
927   clientReadCallback.tcpSocket_ = sock;
928   sock->setReadCB(&clientReadCallback);
929
930   EventBaseAborter eba(&eventBase, 3000);
931   eventBase.loop();
932
933   EXPECT_EQ(readCallback.state, STATE_WAITING);
934
935   cerr << "SSLServerTimeoutTest test completed" << endl;
936 }
937
938 /**
939  * Test SSL server accept timeout with cache path
940  */
941 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
942   // Start listening on a local port
943   WriteCallbackBase writeCallback;
944   ReadCallback readCallback(&writeCallback);
945   HandshakeCallback handshakeCallback(&readCallback);
946   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
947   TestSSLAsyncCacheServer server(&acceptCallback);
948
949   // Set up SSL client
950   EventBase eventBase;
951   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
952
953   client->connect();
954   EventBaseAborter eba(&eventBase, 3000);
955   eventBase.loop();
956
957   EXPECT_EQ(server.getAsyncCallbacks(), 1);
958   EXPECT_EQ(server.getAsyncLookups(), 1);
959   EXPECT_EQ(client->getErrors(), 1);
960   EXPECT_EQ(client->getMiss(), 1);
961   EXPECT_EQ(client->getHit(), 0);
962
963   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
964 }
965
966 /**
967  * Test SSL server accept timeout with cache path
968  */
969 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
970   // Start listening on a local port
971   WriteCallbackBase writeCallback;
972   ReadCallback readCallback(&writeCallback);
973   HandshakeCallback handshakeCallback(&readCallback,
974                                       HandshakeCallback::EXPECT_ERROR);
975   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
976   TestSSLAsyncCacheServer server(&acceptCallback, 500);
977
978   // Set up SSL client
979   EventBase eventBase;
980   auto client =
981       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
982
983   client->connect();
984   EventBaseAborter eba(&eventBase, 3000);
985   eventBase.loop();
986
987   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
988       handshakeCallback.closeSocket();});
989   // give time for the cache lookup to come back and find it closed
990   handshakeCallback.waitForHandshake();
991
992   EXPECT_EQ(server.getAsyncCallbacks(), 1);
993   EXPECT_EQ(server.getAsyncLookups(), 1);
994   EXPECT_EQ(client->getErrors(), 1);
995   EXPECT_EQ(client->getMiss(), 1);
996   EXPECT_EQ(client->getHit(), 0);
997
998   cerr << "SSLServerCacheCloseTest test completed" << endl;
999 }
1000 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1001
1002 /**
1003  * Verify Client Ciphers obtained using SSL MSG Callback.
1004  */
1005 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1006   EventBase eventBase;
1007   auto clientCtx = std::make_shared<SSLContext>();
1008   auto serverCtx = std::make_shared<SSLContext>();
1009   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1010   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1011   serverCtx->loadPrivateKey(kTestKey);
1012   serverCtx->loadCertificate(kTestCert);
1013   serverCtx->loadTrustedCertificates(kTestCA);
1014   serverCtx->loadClientCAList(kTestCA);
1015
1016   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1017   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1018   clientCtx->loadPrivateKey(kTestKey);
1019   clientCtx->loadCertificate(kTestCert);
1020   clientCtx->loadTrustedCertificates(kTestCA);
1021
1022   int fds[2];
1023   getfds(fds);
1024
1025   AsyncSSLSocket::UniquePtr clientSock(
1026       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1027   AsyncSSLSocket::UniquePtr serverSock(
1028       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1029
1030   SSLHandshakeClient client(std::move(clientSock), true, true);
1031   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1032
1033   eventBase.loop();
1034
1035 #if defined(OPENSSL_IS_BORINGSSL)
1036   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1037 #else
1038   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1039 #endif
1040   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1041   EXPECT_TRUE(client.handshakeVerify_);
1042   EXPECT_TRUE(client.handshakeSuccess_);
1043   EXPECT_TRUE(!client.handshakeError_);
1044   EXPECT_TRUE(server.handshakeVerify_);
1045   EXPECT_TRUE(server.handshakeSuccess_);
1046   EXPECT_TRUE(!server.handshakeError_);
1047 }
1048
1049 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1050   EventBase eventBase;
1051   auto ctx = std::make_shared<SSLContext>();
1052
1053   int fds[2];
1054   getfds(fds);
1055
1056   int bufLen = 42;
1057   uint8_t majorVersion = 18;
1058   uint8_t minorVersion = 25;
1059
1060   // Create callback buf
1061   auto buf = IOBuf::create(bufLen);
1062   buf->append(bufLen);
1063   folly::io::RWPrivateCursor cursor(buf.get());
1064   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1065   cursor.write<uint16_t>(0);
1066   cursor.write<uint8_t>(38);
1067   cursor.write<uint8_t>(majorVersion);
1068   cursor.write<uint8_t>(minorVersion);
1069   cursor.skip(32);
1070   cursor.write<uint32_t>(0);
1071
1072   SSL* ssl = ctx->createSSL();
1073   SCOPE_EXIT { SSL_free(ssl); };
1074   AsyncSSLSocket::UniquePtr sock(
1075       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1076   sock->enableClientHelloParsing();
1077
1078   // Test client hello parsing in one packet
1079   AsyncSSLSocket::clientHelloParsingCallback(
1080       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1081   buf.reset();
1082
1083   auto parsedClientHello = sock->getClientHelloInfo();
1084   EXPECT_TRUE(parsedClientHello != nullptr);
1085   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1086   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1087 }
1088
1089 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1090   EventBase eventBase;
1091   auto ctx = std::make_shared<SSLContext>();
1092
1093   int fds[2];
1094   getfds(fds);
1095
1096   int bufLen = 42;
1097   uint8_t majorVersion = 18;
1098   uint8_t minorVersion = 25;
1099
1100   // Create callback buf
1101   auto buf = IOBuf::create(bufLen);
1102   buf->append(bufLen);
1103   folly::io::RWPrivateCursor cursor(buf.get());
1104   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1105   cursor.write<uint16_t>(0);
1106   cursor.write<uint8_t>(38);
1107   cursor.write<uint8_t>(majorVersion);
1108   cursor.write<uint8_t>(minorVersion);
1109   cursor.skip(32);
1110   cursor.write<uint32_t>(0);
1111
1112   SSL* ssl = ctx->createSSL();
1113   SCOPE_EXIT { SSL_free(ssl); };
1114   AsyncSSLSocket::UniquePtr sock(
1115       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1116   sock->enableClientHelloParsing();
1117
1118   // Test parsing with two packets with first packet size < 3
1119   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1120   AsyncSSLSocket::clientHelloParsingCallback(
1121       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1122       ssl, sock.get());
1123   bufCopy.reset();
1124   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1125   AsyncSSLSocket::clientHelloParsingCallback(
1126       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1127       ssl, sock.get());
1128   bufCopy.reset();
1129
1130   auto parsedClientHello = sock->getClientHelloInfo();
1131   EXPECT_TRUE(parsedClientHello != nullptr);
1132   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1133   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1134 }
1135
1136 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1137   EventBase eventBase;
1138   auto ctx = std::make_shared<SSLContext>();
1139
1140   int fds[2];
1141   getfds(fds);
1142
1143   int bufLen = 42;
1144   uint8_t majorVersion = 18;
1145   uint8_t minorVersion = 25;
1146
1147   // Create callback buf
1148   auto buf = IOBuf::create(bufLen);
1149   buf->append(bufLen);
1150   folly::io::RWPrivateCursor cursor(buf.get());
1151   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1152   cursor.write<uint16_t>(0);
1153   cursor.write<uint8_t>(38);
1154   cursor.write<uint8_t>(majorVersion);
1155   cursor.write<uint8_t>(minorVersion);
1156   cursor.skip(32);
1157   cursor.write<uint32_t>(0);
1158
1159   SSL* ssl = ctx->createSSL();
1160   SCOPE_EXIT { SSL_free(ssl); };
1161   AsyncSSLSocket::UniquePtr sock(
1162       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1163   sock->enableClientHelloParsing();
1164
1165   // Test parsing with multiple small packets
1166   for (uint64_t i = 0; i < buf->length(); i += 3) {
1167     auto bufCopy = folly::IOBuf::copyBuffer(
1168         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1169     AsyncSSLSocket::clientHelloParsingCallback(
1170         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1171         ssl, sock.get());
1172     bufCopy.reset();
1173   }
1174
1175   auto parsedClientHello = sock->getClientHelloInfo();
1176   EXPECT_TRUE(parsedClientHello != nullptr);
1177   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1178   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1179 }
1180
1181 /**
1182  * Verify sucessful behavior of SSL certificate validation.
1183  */
1184 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1185   EventBase eventBase;
1186   auto clientCtx = std::make_shared<SSLContext>();
1187   auto dfServerCtx = std::make_shared<SSLContext>();
1188
1189   int fds[2];
1190   getfds(fds);
1191   getctx(clientCtx, dfServerCtx);
1192
1193   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1194   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1195
1196   AsyncSSLSocket::UniquePtr clientSock(
1197     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1198   AsyncSSLSocket::UniquePtr serverSock(
1199     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1200
1201   SSLHandshakeClient client(std::move(clientSock), true, true);
1202   clientCtx->loadTrustedCertificates(kTestCA);
1203
1204   SSLHandshakeServer server(std::move(serverSock), true, true);
1205
1206   eventBase.loop();
1207
1208   EXPECT_TRUE(client.handshakeVerify_);
1209   EXPECT_TRUE(client.handshakeSuccess_);
1210   EXPECT_TRUE(!client.handshakeError_);
1211   EXPECT_LE(0, client.handshakeTime.count());
1212   EXPECT_TRUE(!server.handshakeVerify_);
1213   EXPECT_TRUE(server.handshakeSuccess_);
1214   EXPECT_TRUE(!server.handshakeError_);
1215   EXPECT_LE(0, server.handshakeTime.count());
1216 }
1217
1218 /**
1219  * Verify that the client's verification callback is able to fail SSL
1220  * connection establishment.
1221  */
1222 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1223   EventBase eventBase;
1224   auto clientCtx = std::make_shared<SSLContext>();
1225   auto dfServerCtx = std::make_shared<SSLContext>();
1226
1227   int fds[2];
1228   getfds(fds);
1229   getctx(clientCtx, dfServerCtx);
1230
1231   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1232   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1233
1234   AsyncSSLSocket::UniquePtr clientSock(
1235     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1236   AsyncSSLSocket::UniquePtr serverSock(
1237     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1238
1239   SSLHandshakeClient client(std::move(clientSock), true, false);
1240   clientCtx->loadTrustedCertificates(kTestCA);
1241
1242   SSLHandshakeServer server(std::move(serverSock), true, true);
1243
1244   eventBase.loop();
1245
1246   EXPECT_TRUE(client.handshakeVerify_);
1247   EXPECT_TRUE(!client.handshakeSuccess_);
1248   EXPECT_TRUE(client.handshakeError_);
1249   EXPECT_LE(0, client.handshakeTime.count());
1250   EXPECT_TRUE(!server.handshakeVerify_);
1251   EXPECT_TRUE(!server.handshakeSuccess_);
1252   EXPECT_TRUE(server.handshakeError_);
1253   EXPECT_LE(0, server.handshakeTime.count());
1254 }
1255
1256 /**
1257  * Verify that the options in SSLContext can be overridden in
1258  * sslConnect/Accept.i.e specifying that no validation should be performed
1259  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1260  * the validation callback.
1261  */
1262 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1263   EventBase eventBase;
1264   auto clientCtx = std::make_shared<SSLContext>();
1265   auto dfServerCtx = std::make_shared<SSLContext>();
1266
1267   int fds[2];
1268   getfds(fds);
1269   getctx(clientCtx, dfServerCtx);
1270
1271   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1272   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1273
1274   AsyncSSLSocket::UniquePtr clientSock(
1275     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1276   AsyncSSLSocket::UniquePtr serverSock(
1277     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1278
1279   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1280   clientCtx->loadTrustedCertificates(kTestCA);
1281
1282   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1283
1284   eventBase.loop();
1285
1286   EXPECT_TRUE(!client.handshakeVerify_);
1287   EXPECT_TRUE(client.handshakeSuccess_);
1288   EXPECT_TRUE(!client.handshakeError_);
1289   EXPECT_LE(0, client.handshakeTime.count());
1290   EXPECT_TRUE(!server.handshakeVerify_);
1291   EXPECT_TRUE(server.handshakeSuccess_);
1292   EXPECT_TRUE(!server.handshakeError_);
1293   EXPECT_LE(0, server.handshakeTime.count());
1294 }
1295
1296 /**
1297  * Verify that the options in SSLContext can be overridden in
1298  * sslConnect/Accept. Enable verification even if context says otherwise.
1299  * Test requireClientCert with client cert
1300  */
1301 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1302   EventBase eventBase;
1303   auto clientCtx = std::make_shared<SSLContext>();
1304   auto serverCtx = std::make_shared<SSLContext>();
1305   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1306   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1307   serverCtx->loadPrivateKey(kTestKey);
1308   serverCtx->loadCertificate(kTestCert);
1309   serverCtx->loadTrustedCertificates(kTestCA);
1310   serverCtx->loadClientCAList(kTestCA);
1311
1312   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1313   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1314   clientCtx->loadPrivateKey(kTestKey);
1315   clientCtx->loadCertificate(kTestCert);
1316   clientCtx->loadTrustedCertificates(kTestCA);
1317
1318   int fds[2];
1319   getfds(fds);
1320
1321   AsyncSSLSocket::UniquePtr clientSock(
1322       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1323   AsyncSSLSocket::UniquePtr serverSock(
1324       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1325
1326   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1327   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1328
1329   eventBase.loop();
1330
1331   EXPECT_TRUE(client.handshakeVerify_);
1332   EXPECT_TRUE(client.handshakeSuccess_);
1333   EXPECT_FALSE(client.handshakeError_);
1334   EXPECT_LE(0, client.handshakeTime.count());
1335   EXPECT_TRUE(server.handshakeVerify_);
1336   EXPECT_TRUE(server.handshakeSuccess_);
1337   EXPECT_FALSE(server.handshakeError_);
1338   EXPECT_LE(0, server.handshakeTime.count());
1339 }
1340
1341 /**
1342  * Verify that the client's verification callback is able to override
1343  * the preverification failure and allow a successful connection.
1344  */
1345 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1346   EventBase eventBase;
1347   auto clientCtx = std::make_shared<SSLContext>();
1348   auto dfServerCtx = std::make_shared<SSLContext>();
1349
1350   int fds[2];
1351   getfds(fds);
1352   getctx(clientCtx, dfServerCtx);
1353
1354   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1355   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1356
1357   AsyncSSLSocket::UniquePtr clientSock(
1358     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1359   AsyncSSLSocket::UniquePtr serverSock(
1360     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1361
1362   SSLHandshakeClient client(std::move(clientSock), false, true);
1363   SSLHandshakeServer server(std::move(serverSock), true, true);
1364
1365   eventBase.loop();
1366
1367   EXPECT_TRUE(client.handshakeVerify_);
1368   EXPECT_TRUE(client.handshakeSuccess_);
1369   EXPECT_TRUE(!client.handshakeError_);
1370   EXPECT_LE(0, client.handshakeTime.count());
1371   EXPECT_TRUE(!server.handshakeVerify_);
1372   EXPECT_TRUE(server.handshakeSuccess_);
1373   EXPECT_TRUE(!server.handshakeError_);
1374   EXPECT_LE(0, server.handshakeTime.count());
1375 }
1376
1377 /**
1378  * Verify that specifying that no validation should be performed allows an
1379  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1380  * callback.
1381  */
1382 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1383   EventBase eventBase;
1384   auto clientCtx = std::make_shared<SSLContext>();
1385   auto dfServerCtx = std::make_shared<SSLContext>();
1386
1387   int fds[2];
1388   getfds(fds);
1389   getctx(clientCtx, dfServerCtx);
1390
1391   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1392   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1393
1394   AsyncSSLSocket::UniquePtr clientSock(
1395     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1396   AsyncSSLSocket::UniquePtr serverSock(
1397     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1398
1399   SSLHandshakeClient client(std::move(clientSock), false, false);
1400   SSLHandshakeServer server(std::move(serverSock), false, false);
1401
1402   eventBase.loop();
1403
1404   EXPECT_TRUE(!client.handshakeVerify_);
1405   EXPECT_TRUE(client.handshakeSuccess_);
1406   EXPECT_TRUE(!client.handshakeError_);
1407   EXPECT_LE(0, client.handshakeTime.count());
1408   EXPECT_TRUE(!server.handshakeVerify_);
1409   EXPECT_TRUE(server.handshakeSuccess_);
1410   EXPECT_TRUE(!server.handshakeError_);
1411   EXPECT_LE(0, server.handshakeTime.count());
1412 }
1413
1414 /**
1415  * Test requireClientCert with client cert
1416  */
1417 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1418   EventBase eventBase;
1419   auto clientCtx = std::make_shared<SSLContext>();
1420   auto serverCtx = std::make_shared<SSLContext>();
1421   serverCtx->setVerificationOption(
1422       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1423   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1424   serverCtx->loadPrivateKey(kTestKey);
1425   serverCtx->loadCertificate(kTestCert);
1426   serverCtx->loadTrustedCertificates(kTestCA);
1427   serverCtx->loadClientCAList(kTestCA);
1428
1429   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1430   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1431   clientCtx->loadPrivateKey(kTestKey);
1432   clientCtx->loadCertificate(kTestCert);
1433   clientCtx->loadTrustedCertificates(kTestCA);
1434
1435   int fds[2];
1436   getfds(fds);
1437
1438   AsyncSSLSocket::UniquePtr clientSock(
1439       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1440   AsyncSSLSocket::UniquePtr serverSock(
1441       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1442
1443   SSLHandshakeClient client(std::move(clientSock), true, true);
1444   SSLHandshakeServer server(std::move(serverSock), true, true);
1445
1446   eventBase.loop();
1447
1448   EXPECT_TRUE(client.handshakeVerify_);
1449   EXPECT_TRUE(client.handshakeSuccess_);
1450   EXPECT_FALSE(client.handshakeError_);
1451   EXPECT_LE(0, client.handshakeTime.count());
1452   EXPECT_TRUE(server.handshakeVerify_);
1453   EXPECT_TRUE(server.handshakeSuccess_);
1454   EXPECT_FALSE(server.handshakeError_);
1455   EXPECT_LE(0, server.handshakeTime.count());
1456 }
1457
1458
1459 /**
1460  * Test requireClientCert with no client cert
1461  */
1462 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1463   EventBase eventBase;
1464   auto clientCtx = std::make_shared<SSLContext>();
1465   auto serverCtx = std::make_shared<SSLContext>();
1466   serverCtx->setVerificationOption(
1467       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1468   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1469   serverCtx->loadPrivateKey(kTestKey);
1470   serverCtx->loadCertificate(kTestCert);
1471   serverCtx->loadTrustedCertificates(kTestCA);
1472   serverCtx->loadClientCAList(kTestCA);
1473   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1474   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1475
1476   int fds[2];
1477   getfds(fds);
1478
1479   AsyncSSLSocket::UniquePtr clientSock(
1480       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1481   AsyncSSLSocket::UniquePtr serverSock(
1482       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1483
1484   SSLHandshakeClient client(std::move(clientSock), false, false);
1485   SSLHandshakeServer server(std::move(serverSock), false, false);
1486
1487   eventBase.loop();
1488
1489   EXPECT_FALSE(server.handshakeVerify_);
1490   EXPECT_FALSE(server.handshakeSuccess_);
1491   EXPECT_TRUE(server.handshakeError_);
1492   EXPECT_LE(0, client.handshakeTime.count());
1493   EXPECT_LE(0, server.handshakeTime.count());
1494 }
1495
1496 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1497   auto cert = getFileAsBuf(kTestCert);
1498   auto key = getFileAsBuf(kTestKey);
1499
1500   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1501   BIO_write(certBio.get(), cert.data(), cert.size());
1502   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1503   BIO_write(keyBio.get(), key.data(), key.size());
1504
1505   // Create SSL structs from buffers to get properties
1506   ssl::X509UniquePtr certStruct(
1507       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1508   ssl::EvpPkeyUniquePtr keyStruct(
1509       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1510   certBio = nullptr;
1511   keyBio = nullptr;
1512
1513   auto origCommonName = getCommonName(certStruct.get());
1514   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1515   certStruct = nullptr;
1516   keyStruct = nullptr;
1517
1518   auto ctx = std::make_shared<SSLContext>();
1519   ctx->loadPrivateKeyFromBufferPEM(key);
1520   ctx->loadCertificateFromBufferPEM(cert);
1521   ctx->loadTrustedCertificates(kTestCA);
1522
1523   ssl::SSLUniquePtr ssl(ctx->createSSL());
1524
1525   auto newCert = SSL_get_certificate(ssl.get());
1526   auto newKey = SSL_get_privatekey(ssl.get());
1527
1528   // Get properties from SSL struct
1529   auto newCommonName = getCommonName(newCert);
1530   auto newKeySize = EVP_PKEY_bits(newKey);
1531
1532   // Check that the key and cert have the expected properties
1533   EXPECT_EQ(origCommonName, newCommonName);
1534   EXPECT_EQ(origKeySize, newKeySize);
1535 }
1536
1537 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1538   EventBase eb;
1539
1540   // Set up SSL context.
1541   auto sslContext = std::make_shared<SSLContext>();
1542   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1543
1544   // create SSL socket
1545   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1546
1547   EXPECT_EQ(1500, socket->getMinWriteSize());
1548
1549   socket->setMinWriteSize(0);
1550   EXPECT_EQ(0, socket->getMinWriteSize());
1551   socket->setMinWriteSize(50000);
1552   EXPECT_EQ(50000, socket->getMinWriteSize());
1553 }
1554
1555 class ReadCallbackTerminator : public ReadCallback {
1556  public:
1557   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1558       : ReadCallback(wcb)
1559       , base_(base) {}
1560
1561   // Do not write data back, terminate the loop.
1562   void readDataAvailable(size_t len) noexcept override {
1563     std::cerr << "readDataAvailable, len " << len << std::endl;
1564
1565     currentBuffer.length = len;
1566
1567     buffers.push_back(currentBuffer);
1568     currentBuffer.reset();
1569     state = STATE_SUCCEEDED;
1570
1571     socket_->setReadCB(nullptr);
1572     base_->terminateLoopSoon();
1573   }
1574  private:
1575   EventBase* base_;
1576 };
1577
1578
1579 /**
1580  * Test a full unencrypted codepath
1581  */
1582 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1583   EventBase base;
1584
1585   auto clientCtx = std::make_shared<folly::SSLContext>();
1586   auto serverCtx = std::make_shared<folly::SSLContext>();
1587   int fds[2];
1588   getfds(fds);
1589   getctx(clientCtx, serverCtx);
1590   auto client = AsyncSSLSocket::newSocket(
1591                   clientCtx, &base, fds[0], false, true);
1592   auto server = AsyncSSLSocket::newSocket(
1593                   serverCtx, &base, fds[1], true, true);
1594
1595   ReadCallbackTerminator readCallback(&base, nullptr);
1596   server->setReadCB(&readCallback);
1597   readCallback.setSocket(server);
1598
1599   uint8_t buf[128];
1600   memset(buf, 'a', sizeof(buf));
1601   client->write(nullptr, buf, sizeof(buf));
1602
1603   // Check that bytes are unencrypted
1604   char c;
1605   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1606   EXPECT_EQ('a', c);
1607
1608   EventBaseAborter eba(&base, 3000);
1609   base.loop();
1610
1611   EXPECT_EQ(1, readCallback.buffers.size());
1612   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1613
1614   server->setReadCB(&readCallback);
1615
1616   // Unencrypted
1617   server->sslAccept(nullptr);
1618   client->sslConn(nullptr);
1619
1620   // Do NOT wait for handshake, writing should be queued and happen after
1621
1622   client->write(nullptr, buf, sizeof(buf));
1623
1624   // Check that bytes are *not* unencrypted
1625   char c2;
1626   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1627   EXPECT_NE('a', c2);
1628
1629
1630   base.loop();
1631
1632   EXPECT_EQ(2, readCallback.buffers.size());
1633   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1634 }
1635
1636 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1637   auto clientCtx = std::make_shared<folly::SSLContext>();
1638   auto serverCtx = std::make_shared<folly::SSLContext>();
1639   getctx(clientCtx, serverCtx);
1640
1641   WriteCallbackBase writeCallback;
1642   ReadCallback readCallback(&writeCallback);
1643   HandshakeCallback handshakeCallback(&readCallback);
1644   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1645   TestSSLServer server(&acceptCallback);
1646
1647   EventBase evb;
1648   std::shared_ptr<AsyncSSLSocket> socket =
1649       AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1650   socket->connect(nullptr, server.getAddress(), 0);
1651
1652   evb.loop();
1653
1654   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1655   socket->sslConn(nullptr);
1656   evb.loop();
1657   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1658
1659   // write()
1660   std::array<uint8_t, 128> buf;
1661   memset(buf.data(), 'a', buf.size());
1662   socket->write(nullptr, buf.data(), buf.size());
1663
1664   socket->close();
1665 }
1666
1667 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1668   // Start listening on a local port
1669   WriteCallbackBase writeCallback;
1670   WriteErrorCallback readCallback(&writeCallback);
1671   HandshakeCallback handshakeCallback(&readCallback,
1672                                       HandshakeCallback::EXPECT_ERROR);
1673   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1674   TestSSLServer server(&acceptCallback);
1675
1676   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1677   socket->open();
1678   uint8_t buf[3] = {0x16, 0x03, 0x01};
1679   socket->write(buf, sizeof(buf));
1680   socket->closeWithReset();
1681
1682   handshakeCallback.waitForHandshake();
1683   EXPECT_NE(
1684       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1685   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1686 }
1687
1688 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1689   // Start listening on a local port
1690   WriteCallbackBase writeCallback;
1691   WriteErrorCallback readCallback(&writeCallback);
1692   HandshakeCallback handshakeCallback(&readCallback,
1693                                       HandshakeCallback::EXPECT_ERROR);
1694   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1695   TestSSLServer server(&acceptCallback);
1696
1697   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1698   socket->open();
1699   uint8_t buf[3] = {0x16, 0x03, 0x01};
1700   socket->write(buf, sizeof(buf));
1701   socket->close();
1702
1703   handshakeCallback.waitForHandshake();
1704 #if FOLLY_OPENSSL_IS_110
1705   EXPECT_NE(
1706       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1707 #else
1708   EXPECT_NE(
1709       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1710 #endif
1711 }
1712
1713 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1714   // Start listening on a local port
1715   WriteCallbackBase writeCallback;
1716   WriteErrorCallback readCallback(&writeCallback);
1717   HandshakeCallback handshakeCallback(&readCallback,
1718                                       HandshakeCallback::EXPECT_ERROR);
1719   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1720   TestSSLServer server(&acceptCallback);
1721
1722   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1723   socket->open();
1724   uint8_t buf[256] = {0x16, 0x03};
1725   memset(buf + 2, 'a', sizeof(buf) - 2);
1726   socket->write(buf, sizeof(buf));
1727   socket->close();
1728
1729   handshakeCallback.waitForHandshake();
1730   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1731             std::string::npos);
1732 #if defined(OPENSSL_IS_BORINGSSL)
1733   EXPECT_NE(
1734       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1735       std::string::npos);
1736 #elif FOLLY_OPENSSL_IS_110
1737   EXPECT_NE(handshakeCallback.errorString_.find("packet length too long"),
1738             std::string::npos);
1739 #else
1740   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1741             std::string::npos);
1742 #endif
1743 }
1744
1745 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1746   using folly::ssl::OpenSSLUtils;
1747   EXPECT_EQ(
1748       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1749   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1750   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1751   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1752   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1753 }
1754
1755 #if FOLLY_ALLOW_TFO
1756
1757 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1758  public:
1759   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1760
1761   explicit MockAsyncTFOSSLSocket(
1762       std::shared_ptr<folly::SSLContext> sslCtx,
1763       EventBase* evb)
1764       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1765
1766   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1767 };
1768
1769 /**
1770  * Test connecting to, writing to, reading from, and closing the
1771  * connection to the SSL server with TFO.
1772  */
1773 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1774   // Start listening on a local port
1775   WriteCallbackBase writeCallback;
1776   ReadCallback readCallback(&writeCallback);
1777   HandshakeCallback handshakeCallback(&readCallback);
1778   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1779   TestSSLServer server(&acceptCallback, true);
1780
1781   // Set up SSL context.
1782   auto sslContext = std::make_shared<SSLContext>();
1783
1784   // connect
1785   auto socket =
1786       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1787   socket->enableTFO();
1788   socket->open();
1789
1790   // write()
1791   std::array<uint8_t, 128> buf;
1792   memset(buf.data(), 'a', buf.size());
1793   socket->write(buf.data(), buf.size());
1794
1795   // read()
1796   std::array<uint8_t, 128> readbuf;
1797   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1798   EXPECT_EQ(bytesRead, 128);
1799   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1800
1801   // close()
1802   socket->close();
1803 }
1804
1805 /**
1806  * Test connecting to, writing to, reading from, and closing the
1807  * connection to the SSL server with TFO.
1808  */
1809 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1810   // Start listening on a local port
1811   WriteCallbackBase writeCallback;
1812   ReadCallback readCallback(&writeCallback);
1813   HandshakeCallback handshakeCallback(&readCallback);
1814   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1815   TestSSLServer server(&acceptCallback, false);
1816
1817   // Set up SSL context.
1818   auto sslContext = std::make_shared<SSLContext>();
1819
1820   // connect
1821   auto socket =
1822       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1823   socket->enableTFO();
1824   socket->open();
1825
1826   // write()
1827   std::array<uint8_t, 128> buf;
1828   memset(buf.data(), 'a', buf.size());
1829   socket->write(buf.data(), buf.size());
1830
1831   // read()
1832   std::array<uint8_t, 128> readbuf;
1833   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1834   EXPECT_EQ(bytesRead, 128);
1835   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1836
1837   // close()
1838   socket->close();
1839 }
1840
1841 class ConnCallback : public AsyncSocket::ConnectCallback {
1842  public:
1843   void connectSuccess() noexcept override {
1844     state = State::SUCCESS;
1845   }
1846
1847   void connectErr(const AsyncSocketException& ex) noexcept override {
1848     state = State::ERROR;
1849     error = ex.what();
1850   }
1851
1852   enum class State { WAITING, SUCCESS, ERROR };
1853
1854   State state{State::WAITING};
1855   std::string error;
1856 };
1857
1858 template <class Cardinality>
1859 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1860     EventBase* evb,
1861     const SocketAddress& address,
1862     Cardinality cardinality) {
1863   // Set up SSL context.
1864   auto sslContext = std::make_shared<SSLContext>();
1865
1866   // connect
1867   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1868       new MockAsyncTFOSSLSocket(sslContext, evb));
1869   socket->enableTFO();
1870
1871   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1872       .Times(cardinality)
1873       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1874         sockaddr_storage addr;
1875         auto len = address.getAddress(&addr);
1876         return connect(fd, (const struct sockaddr*)&addr, len);
1877       }));
1878   return socket;
1879 }
1880
1881 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1882   // Start listening on a local port
1883   WriteCallbackBase writeCallback;
1884   ReadCallback readCallback(&writeCallback);
1885   HandshakeCallback handshakeCallback(&readCallback);
1886   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1887   TestSSLServer server(&acceptCallback, true);
1888
1889   EventBase evb;
1890
1891   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1892   ConnCallback ccb;
1893   socket->connect(&ccb, server.getAddress(), 30);
1894
1895   evb.loop();
1896   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1897
1898   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1899   evb.loop();
1900
1901   BlockingSocket sock(std::move(socket));
1902   // write()
1903   std::array<uint8_t, 128> buf;
1904   memset(buf.data(), 'a', buf.size());
1905   sock.write(buf.data(), buf.size());
1906
1907   // read()
1908   std::array<uint8_t, 128> readbuf;
1909   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1910   EXPECT_EQ(bytesRead, 128);
1911   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1912
1913   // close()
1914   sock.close();
1915 }
1916
1917 #if !defined(OPENSSL_IS_BORINGSSL)
1918 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1919   // Start listening on a local port
1920   ConnectTimeoutCallback acceptCallback;
1921   TestSSLServer server(&acceptCallback, true);
1922
1923   // Set up SSL context.
1924   auto sslContext = std::make_shared<SSLContext>();
1925
1926   // connect
1927   auto socket =
1928       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1929   socket->enableTFO();
1930   EXPECT_THROW(
1931       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1932 }
1933 #endif
1934
1935 #if !defined(OPENSSL_IS_BORINGSSL)
1936 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1937   // Start listening on a local port
1938   ConnectTimeoutCallback acceptCallback;
1939   TestSSLServer server(&acceptCallback, true);
1940
1941   EventBase evb;
1942
1943   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1944   ConnCallback ccb;
1945   // Set a short timeout
1946   socket->connect(&ccb, server.getAddress(), 1);
1947
1948   evb.loop();
1949   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1950 }
1951 #endif
1952
1953 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1954   // Start listening on a local port
1955   EmptyReadCallback readCallback;
1956   HandshakeCallback handshakeCallback(
1957       &readCallback, HandshakeCallback::EXPECT_ERROR);
1958   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1959   TestSSLServer server(&acceptCallback, true);
1960
1961   EventBase evb;
1962
1963   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1964   ConnCallback ccb;
1965   socket->connect(&ccb, server.getAddress(), 100);
1966
1967   evb.loop();
1968   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1969   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1970 }
1971
1972 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1973   // Start listening on a local port
1974   EventBase evb;
1975
1976   // Hopefully nothing is listening on this address
1977   SocketAddress addr("127.0.0.1", 65535);
1978   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1979   ConnCallback ccb;
1980   socket->connect(&ccb, addr, 100);
1981
1982   evb.loop();
1983   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1984   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1985 }
1986
1987 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1988   EventBase clientEventBase;
1989   EventBase serverEventBase;
1990   auto clientCtx = std::make_shared<SSLContext>();
1991   auto dfServerCtx = std::make_shared<SSLContext>();
1992   std::array<int, 2> fds;
1993   getfds(fds.data());
1994   getctx(clientCtx, dfServerCtx);
1995
1996   AsyncSSLSocket::UniquePtr clientSockPtr(
1997       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1998   AsyncSSLSocket::UniquePtr serverSockPtr(
1999       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
2000   auto clientSock = clientSockPtr.get();
2001   auto serverSock = serverSockPtr.get();
2002   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2003
2004   // Steal some data from the server.
2005   clientEventBase.loopOnce();
2006   std::array<uint8_t, 10> buf;
2007   recv(fds[1], buf.data(), buf.size(), 0);
2008
2009   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2010   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2011   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2012     serverEventBase.loopOnce();
2013     clientEventBase.loopOnce();
2014   }
2015
2016   EXPECT_TRUE(client.handshakeSuccess_);
2017   EXPECT_TRUE(server.handshakeSuccess_);
2018   EXPECT_EQ(
2019       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2020 }
2021
2022 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2023   EventBase clientEventBase;
2024   EventBase serverEventBase;
2025   auto clientCtx = std::make_shared<SSLContext>();
2026   auto dfServerCtx = std::make_shared<SSLContext>();
2027   std::array<int, 2> fds;
2028   getfds(fds.data());
2029   getctx(clientCtx, dfServerCtx);
2030
2031   AsyncSSLSocket::UniquePtr clientSockPtr(
2032       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2033   AsyncSocket::UniquePtr serverSockPtr(
2034       new AsyncSocket(&serverEventBase, fds[1]));
2035   auto clientSock = clientSockPtr.get();
2036   auto serverSock = serverSockPtr.get();
2037   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2038
2039   // Steal some data from the server.
2040   clientEventBase.loopOnce();
2041   std::array<uint8_t, 10> buf;
2042   recv(fds[1], buf.data(), buf.size(), 0);
2043   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2044   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2045       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2046   auto serverSSLSock = serverSSLSockPtr.get();
2047   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2048   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2049     serverEventBase.loopOnce();
2050     clientEventBase.loopOnce();
2051   }
2052
2053   EXPECT_TRUE(client.handshakeSuccess_);
2054   EXPECT_TRUE(server.handshakeSuccess_);
2055   EXPECT_EQ(
2056       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2057 }
2058
2059 /**
2060  * Test overriding the flags passed to "sendmsg()" system call,
2061  * and verifying that write requests fail properly.
2062  */
2063 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2064   // Start listening on a local port
2065   SendMsgFlagsCallback msgCallback;
2066   ExpectWriteErrorCallback writeCallback(&msgCallback);
2067   ReadCallback readCallback(&writeCallback);
2068   HandshakeCallback handshakeCallback(&readCallback);
2069   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2070   TestSSLServer server(&acceptCallback);
2071
2072   // Set up SSL context.
2073   auto sslContext = std::make_shared<SSLContext>();
2074   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2075
2076   // connect
2077   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2078                                                  sslContext);
2079   socket->open();
2080
2081   // Setting flags to "-1" to trigger "Invalid argument" error
2082   // on attempt to use this flags in sendmsg() system call.
2083   msgCallback.resetFlags(-1);
2084
2085   // write()
2086   std::vector<uint8_t> buf(128, 'a');
2087   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2088
2089   // close()
2090   socket->close();
2091
2092   cerr << "SendMsgParamsCallback test completed" << endl;
2093 }
2094
2095 #ifdef MSG_ERRQUEUE
2096 /**
2097  * Test connecting to, writing to, reading from, and closing the
2098  * connection to the SSL server.
2099  */
2100 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2101   // This test requires Linux kernel v4.6 or later
2102   struct utsname s_uname;
2103   memset(&s_uname, 0, sizeof(s_uname));
2104   ASSERT_EQ(uname(&s_uname), 0);
2105   int major, minor;
2106   folly::StringPiece extra;
2107   if (folly::split<false>(
2108         '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2109     if (major < 4 || (major == 4 && minor < 6)) {
2110       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2111                 << "kernel ver. " << s_uname.release << " detected).";
2112       return;
2113     }
2114   }
2115
2116   // Start listening on a local port
2117   SendMsgDataCallback msgCallback;
2118   WriteCheckTimestampCallback writeCallback(&msgCallback);
2119   ReadCallback readCallback(&writeCallback);
2120   HandshakeCallback handshakeCallback(&readCallback);
2121   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2122   TestSSLServer server(&acceptCallback);
2123
2124   // Set up SSL context.
2125   auto sslContext = std::make_shared<SSLContext>();
2126   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2127
2128   // connect
2129   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2130                                                  sslContext);
2131   socket->open();
2132
2133   // Adding MSG_EOR flag to the message flags - it'll trigger
2134   // timestamp generation for the last byte of the message.
2135   msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
2136
2137   // Init ancillary data buffer to trigger timestamp notification
2138   union {
2139     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2140     struct cmsghdr cmsg;
2141   } u;
2142   u.cmsg.cmsg_level = SOL_SOCKET;
2143   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2144   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2145   uint32_t flags =
2146       SOF_TIMESTAMPING_TX_SCHED |
2147       SOF_TIMESTAMPING_TX_SOFTWARE |
2148       SOF_TIMESTAMPING_TX_ACK;
2149   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2150   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2151   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2152   msgCallback.resetData(std::move(ctrl));
2153
2154   // write()
2155   std::vector<uint8_t> buf(128, 'a');
2156   socket->write(buf.data(), buf.size());
2157
2158   // read()
2159   std::vector<uint8_t> readbuf(buf.size());
2160   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2161   EXPECT_EQ(bytesRead, buf.size());
2162   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2163
2164   writeCallback.checkForTimestampNotifications();
2165
2166   // close()
2167   socket->close();
2168
2169   cerr << "SendMsgDataCallback test completed" << endl;
2170 }
2171 #endif // MSG_ERRQUEUE
2172
2173 #endif
2174
2175 } // namespace
2176
2177 #ifdef SIGPIPE
2178 ///////////////////////////////////////////////////////////////////////////
2179 // init_unit_test_suite
2180 ///////////////////////////////////////////////////////////////////////////
2181 namespace {
2182 struct Initializer {
2183   Initializer() {
2184     signal(SIGPIPE, SIG_IGN);
2185   }
2186 };
2187 Initializer initializer;
2188 } // anonymous
2189 #endif