a8e4e4482bdf96dda536dbe5a460cc4013f3520c
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #ifdef MSG_ERRQUEUE
41 #include <sys/utsname.h>
42 #endif
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::list;
50
51 using namespace testing;
52
53 namespace folly {
54 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
55 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
56 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
57
58 constexpr size_t SSLClient::kMaxReadBufferSz;
59 constexpr size_t SSLClient::kMaxReadsPerEvent;
60
61 void getfds(int fds[2]) {
62   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
63     FAIL() << "failed to create socketpair: " << strerror(errno);
64   }
65   for (int idx = 0; idx < 2; ++idx) {
66     int flags = fcntl(fds[idx], F_GETFL, 0);
67     if (flags == -1) {
68       FAIL() << "failed to get flags for socket " << idx << ": "
69              << strerror(errno);
70     }
71     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
72       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
73              << strerror(errno);
74     }
75   }
76 }
77
78 void getctx(
79   std::shared_ptr<folly::SSLContext> clientCtx,
80   std::shared_ptr<folly::SSLContext> serverCtx) {
81   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
82
83   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
84   serverCtx->loadCertificate(kTestCert);
85   serverCtx->loadPrivateKey(kTestKey);
86 }
87
88 void sslsocketpair(
89   EventBase* eventBase,
90   AsyncSSLSocket::UniquePtr* clientSock,
91   AsyncSSLSocket::UniquePtr* serverSock) {
92   auto clientCtx = std::make_shared<folly::SSLContext>();
93   auto serverCtx = std::make_shared<folly::SSLContext>();
94   int fds[2];
95   getfds(fds);
96   getctx(clientCtx, serverCtx);
97   clientSock->reset(new AsyncSSLSocket(
98                       clientCtx, eventBase, fds[0], false));
99   serverSock->reset(new AsyncSSLSocket(
100                       serverCtx, eventBase, fds[1], true));
101
102   // (*clientSock)->setSendTimeout(100);
103   // (*serverSock)->setSendTimeout(100);
104 }
105
106 // client protocol filters
107 bool clientProtoFilterPickPony(unsigned char** client,
108   unsigned int* client_len, const unsigned char*, unsigned int ) {
109   //the protocol string in length prefixed byte string. the
110   //length byte is not included in the length
111   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
112   *client = p;
113   *client_len = 7;
114   return true;
115 }
116
117 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
118   const unsigned char*, unsigned int) {
119   return false;
120 }
121
122 std::string getFileAsBuf(const char* fileName) {
123   std::string buffer;
124   folly::readFile(fileName, buffer);
125   return buffer;
126 }
127
128 std::string getCommonName(X509* cert) {
129   X509_NAME* subject = X509_get_subject_name(cert);
130   std::string cn;
131   cn.resize(ub_common_name);
132   X509_NAME_get_text_by_NID(
133       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
134   return cn;
135 }
136
137 /**
138  * Test connecting to, writing to, reading from, and closing the
139  * connection to the SSL server.
140  */
141 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
142   // Start listening on a local port
143   WriteCallbackBase writeCallback;
144   ReadCallback readCallback(&writeCallback);
145   HandshakeCallback handshakeCallback(&readCallback);
146   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
147   TestSSLServer server(&acceptCallback);
148
149   // Set up SSL context.
150   std::shared_ptr<SSLContext> sslContext(new SSLContext());
151   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
152   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
153   //sslContext->authenticate(true, false);
154
155   // connect
156   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
157                                                  sslContext);
158   socket->open(std::chrono::milliseconds(10000));
159
160   // write()
161   uint8_t buf[128];
162   memset(buf, 'a', sizeof(buf));
163   socket->write(buf, sizeof(buf));
164
165   // read()
166   uint8_t readbuf[128];
167   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
168   EXPECT_EQ(bytesRead, 128);
169   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
170
171   // close()
172   socket->close();
173
174   cerr << "ConnectWriteReadClose test completed" << endl;
175   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
176 }
177
178 /**
179  * Test reading after server close.
180  */
181 TEST(AsyncSSLSocketTest, ReadAfterClose) {
182   // Start listening on a local port
183   WriteCallbackBase writeCallback;
184   ReadEOFCallback readCallback(&writeCallback);
185   HandshakeCallback handshakeCallback(&readCallback);
186   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
187   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
188
189   // Set up SSL context.
190   auto sslContext = std::make_shared<SSLContext>();
191   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
192
193   auto socket =
194       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
195   socket->open();
196
197   // This should trigger an EOF on the client.
198   auto evb = handshakeCallback.getSocket()->getEventBase();
199   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
200   std::array<uint8_t, 128> readbuf;
201   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
202   EXPECT_EQ(0, bytesRead);
203 }
204
205 /**
206  * Test bad renegotiation
207  */
208 #if !defined(OPENSSL_IS_BORINGSSL)
209 TEST(AsyncSSLSocketTest, Renegotiate) {
210   EventBase eventBase;
211   auto clientCtx = std::make_shared<SSLContext>();
212   auto dfServerCtx = std::make_shared<SSLContext>();
213   std::array<int, 2> fds;
214   getfds(fds.data());
215   getctx(clientCtx, dfServerCtx);
216
217   AsyncSSLSocket::UniquePtr clientSock(
218       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
219   AsyncSSLSocket::UniquePtr serverSock(
220       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
221   SSLHandshakeClient client(std::move(clientSock), true, true);
222   RenegotiatingServer server(std::move(serverSock));
223
224   while (!client.handshakeSuccess_ && !client.handshakeError_) {
225     eventBase.loopOnce();
226   }
227
228   ASSERT_TRUE(client.handshakeSuccess_);
229
230   auto sslSock = std::move(client).moveSocket();
231   sslSock->detachEventBase();
232   // This is nasty, however we don't want to add support for
233   // renegotiation in AsyncSSLSocket.
234   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
235
236   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
237
238   std::thread t([&]() { eventBase.loopForever(); });
239
240   // Trigger the renegotiation.
241   std::array<uint8_t, 128> buf;
242   memset(buf.data(), 'a', buf.size());
243   try {
244     socket->write(buf.data(), buf.size());
245   } catch (AsyncSocketException& e) {
246     LOG(INFO) << "client got error " << e.what();
247   }
248   eventBase.terminateLoopSoon();
249   t.join();
250
251   eventBase.loop();
252   ASSERT_TRUE(server.renegotiationError_);
253 }
254 #endif
255
256 /**
257  * Negative test for handshakeError().
258  */
259 TEST(AsyncSSLSocketTest, HandshakeError) {
260   // Start listening on a local port
261   WriteCallbackBase writeCallback;
262   WriteErrorCallback readCallback(&writeCallback);
263   HandshakeCallback handshakeCallback(&readCallback);
264   HandshakeErrorCallback acceptCallback(&handshakeCallback);
265   TestSSLServer server(&acceptCallback);
266
267   // Set up SSL context.
268   std::shared_ptr<SSLContext> sslContext(new SSLContext());
269   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
270
271   // connect
272   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
273                                                  sslContext);
274   // read()
275   bool ex = false;
276   try {
277     socket->open();
278
279     uint8_t readbuf[128];
280     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
281     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
282   } catch (AsyncSocketException&) {
283     ex = true;
284   }
285   EXPECT_TRUE(ex);
286
287   // close()
288   socket->close();
289   cerr << "HandshakeError test completed" << endl;
290 }
291
292 /**
293  * Negative test for readError().
294  */
295 TEST(AsyncSSLSocketTest, ReadError) {
296   // Start listening on a local port
297   WriteCallbackBase writeCallback;
298   ReadErrorCallback readCallback(&writeCallback);
299   HandshakeCallback handshakeCallback(&readCallback);
300   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
301   TestSSLServer server(&acceptCallback);
302
303   // Set up SSL context.
304   std::shared_ptr<SSLContext> sslContext(new SSLContext());
305   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
306
307   // connect
308   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
309                                                  sslContext);
310   socket->open();
311
312   // write something to trigger ssl handshake
313   uint8_t buf[128];
314   memset(buf, 'a', sizeof(buf));
315   socket->write(buf, sizeof(buf));
316
317   socket->close();
318   cerr << "ReadError test completed" << endl;
319 }
320
321 /**
322  * Negative test for writeError().
323  */
324 TEST(AsyncSSLSocketTest, WriteError) {
325   // Start listening on a local port
326   WriteCallbackBase writeCallback;
327   WriteErrorCallback readCallback(&writeCallback);
328   HandshakeCallback handshakeCallback(&readCallback);
329   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
330   TestSSLServer server(&acceptCallback);
331
332   // Set up SSL context.
333   std::shared_ptr<SSLContext> sslContext(new SSLContext());
334   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
335
336   // connect
337   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
338                                                  sslContext);
339   socket->open();
340
341   // write something to trigger ssl handshake
342   uint8_t buf[128];
343   memset(buf, 'a', sizeof(buf));
344   socket->write(buf, sizeof(buf));
345
346   socket->close();
347   cerr << "WriteError test completed" << endl;
348 }
349
350 /**
351  * Test a socket with TCP_NODELAY unset.
352  */
353 TEST(AsyncSSLSocketTest, SocketWithDelay) {
354   // Start listening on a local port
355   WriteCallbackBase writeCallback;
356   ReadCallback readCallback(&writeCallback);
357   HandshakeCallback handshakeCallback(&readCallback);
358   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
359   TestSSLServer server(&acceptCallback);
360
361   // Set up SSL context.
362   std::shared_ptr<SSLContext> sslContext(new SSLContext());
363   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
364
365   // connect
366   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
367                                                  sslContext);
368   socket->open();
369
370   // write()
371   uint8_t buf[128];
372   memset(buf, 'a', sizeof(buf));
373   socket->write(buf, sizeof(buf));
374
375   // read()
376   uint8_t readbuf[128];
377   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
378   EXPECT_EQ(bytesRead, 128);
379   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
380
381   // close()
382   socket->close();
383
384   cerr << "SocketWithDelay test completed" << endl;
385 }
386
387 using NextProtocolTypePair =
388     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
389
390 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
391   // For matching protos
392  public:
393   void SetUp() override { getctx(clientCtx, serverCtx); }
394
395   void connect(bool unset = false) {
396     getfds(fds);
397
398     if (unset) {
399       // unsetting NPN for any of [client, server] is enough to make NPN not
400       // work
401       clientCtx->unsetNextProtocols();
402     }
403
404     AsyncSSLSocket::UniquePtr clientSock(
405       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
406     AsyncSSLSocket::UniquePtr serverSock(
407       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
408     client = std::make_unique<NpnClient>(std::move(clientSock));
409     server = std::make_unique<NpnServer>(std::move(serverSock));
410
411     eventBase.loop();
412   }
413
414   void expectProtocol(const std::string& proto) {
415     expectHandshakeSuccess();
416     EXPECT_NE(client->nextProtoLength, 0);
417     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
418     EXPECT_EQ(
419         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
420         0);
421     string selected((const char*)client->nextProto, client->nextProtoLength);
422     EXPECT_EQ(proto, selected);
423   }
424
425   void expectNoProtocol() {
426     expectHandshakeSuccess();
427     EXPECT_EQ(client->nextProtoLength, 0);
428     EXPECT_EQ(server->nextProtoLength, 0);
429     EXPECT_EQ(client->nextProto, nullptr);
430     EXPECT_EQ(server->nextProto, nullptr);
431   }
432
433   void expectProtocolType() {
434     expectHandshakeSuccess();
435     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
436         GetParam().second == SSLContext::NextProtocolType::ANY) {
437       EXPECT_EQ(client->protocolType, server->protocolType);
438     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
439                GetParam().second == SSLContext::NextProtocolType::ANY) {
440       // Well not much we can say
441     } else {
442       expectProtocolType(GetParam());
443     }
444   }
445
446   void expectProtocolType(NextProtocolTypePair expected) {
447     expectHandshakeSuccess();
448     EXPECT_EQ(client->protocolType, expected.first);
449     EXPECT_EQ(server->protocolType, expected.second);
450   }
451
452   void expectHandshakeSuccess() {
453     EXPECT_FALSE(client->except.hasValue())
454         << "client handshake error: " << client->except->what();
455     EXPECT_FALSE(server->except.hasValue())
456         << "server handshake error: " << server->except->what();
457   }
458
459   void expectHandshakeError() {
460     EXPECT_TRUE(client->except.hasValue())
461         << "Expected client handshake error!";
462     EXPECT_TRUE(server->except.hasValue())
463         << "Expected server handshake error!";
464   }
465
466   EventBase eventBase;
467   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
468   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
469   int fds[2];
470   std::unique_ptr<NpnClient> client;
471   std::unique_ptr<NpnServer> server;
472 };
473
474 class NextProtocolTLSExtTest : public NextProtocolTest {
475   // For extended TLS protos
476 };
477
478 class NextProtocolNPNOnlyTest : public NextProtocolTest {
479   // For mismatching protos
480 };
481
482 class NextProtocolMismatchTest : public NextProtocolTest {
483   // For mismatching protos
484 };
485
486 TEST_P(NextProtocolTest, NpnTestOverlap) {
487   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
488   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
489                                         GetParam().second);
490
491   connect();
492
493   expectProtocol("baz");
494   expectProtocolType();
495 }
496
497 TEST_P(NextProtocolTest, NpnTestUnset) {
498   // Identical to above test, except that we want unset NPN before
499   // looping.
500   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
501   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
502                                         GetParam().second);
503
504   connect(true /* unset */);
505
506   // if alpn negotiation fails, type will appear as npn
507   expectNoProtocol();
508   EXPECT_EQ(client->protocolType, server->protocolType);
509 }
510
511 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
512   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
513   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
514                                         GetParam().second);
515
516   connect();
517
518   expectNoProtocol();
519   expectProtocolType(
520       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
521 }
522
523 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
524 // will fail on 1.0.2 before that.
525 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529   connect();
530
531   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
532       GetParam().second == SSLContext::NextProtocolType::ALPN) {
533     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
534     // mismatch should result in a fatal alert, but this is the current behavior
535     // on all OpenSSL versions/variants, and we want to know if it changes.
536     expectNoProtocol();
537   }
538 #if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
539   else if (
540       GetParam().first == SSLContext::NextProtocolType::ANY &&
541       GetParam().second == SSLContext::NextProtocolType::ANY) {
542 # if FOLLY_OPENSSL_IS_110
543     // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
544     // correct behavior per RFC7301
545     expectHandshakeError();
546 # else
547     // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
548     // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
549     // NPN *and* ALPN
550     expectNoProtocol();
551 # endif
552   }
553 #endif
554    else {
555     expectProtocol("blub");
556     expectProtocolType(
557         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
558   }
559 }
560
561 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
562   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
563   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
564   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
565                                         GetParam().second);
566
567   connect();
568
569   expectProtocol("ponies");
570   expectProtocolType();
571 }
572
573 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
574   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
575   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
576   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
577                                         GetParam().second);
578
579   connect();
580
581   expectProtocol("blub");
582   expectProtocolType();
583 }
584
585 TEST_P(NextProtocolTest, RandomizedNpnTest) {
586   // Probability that this test will fail is 2^-64, which could be considered
587   // as negligible.
588   const int kTries = 64;
589
590   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
591                                         GetParam().first);
592   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
593                                                   GetParam().second);
594
595   std::set<string> selectedProtocols;
596   for (int i = 0; i < kTries; ++i) {
597     connect();
598
599     EXPECT_NE(client->nextProtoLength, 0);
600     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
601     EXPECT_EQ(
602         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
603         0);
604     string selected((const char*)client->nextProto, client->nextProtoLength);
605     selectedProtocols.insert(selected);
606     expectProtocolType();
607   }
608   EXPECT_EQ(selectedProtocols.size(), 2);
609 }
610
611 INSTANTIATE_TEST_CASE_P(
612     AsyncSSLSocketTest,
613     NextProtocolTest,
614     ::testing::Values(
615         NextProtocolTypePair(
616             SSLContext::NextProtocolType::NPN,
617             SSLContext::NextProtocolType::NPN),
618         NextProtocolTypePair(
619             SSLContext::NextProtocolType::NPN,
620             SSLContext::NextProtocolType::ANY),
621         NextProtocolTypePair(
622             SSLContext::NextProtocolType::ANY,
623             SSLContext::NextProtocolType::ANY)));
624
625 #if FOLLY_OPENSSL_HAS_ALPN
626 INSTANTIATE_TEST_CASE_P(
627     AsyncSSLSocketTest,
628     NextProtocolTLSExtTest,
629     ::testing::Values(
630         NextProtocolTypePair(
631             SSLContext::NextProtocolType::ALPN,
632             SSLContext::NextProtocolType::ALPN),
633         NextProtocolTypePair(
634             SSLContext::NextProtocolType::ALPN,
635             SSLContext::NextProtocolType::ANY),
636         NextProtocolTypePair(
637             SSLContext::NextProtocolType::ANY,
638             SSLContext::NextProtocolType::ALPN)));
639 #endif
640
641 INSTANTIATE_TEST_CASE_P(
642     AsyncSSLSocketTest,
643     NextProtocolNPNOnlyTest,
644     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
645                                            SSLContext::NextProtocolType::NPN)));
646
647 #if FOLLY_OPENSSL_HAS_ALPN
648 INSTANTIATE_TEST_CASE_P(
649     AsyncSSLSocketTest,
650     NextProtocolMismatchTest,
651     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
652                                            SSLContext::NextProtocolType::ALPN),
653                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
654                                            SSLContext::NextProtocolType::NPN)));
655 #endif
656
657 #ifndef OPENSSL_NO_TLSEXT
658 /**
659  * 1. Client sends TLSEXT_HOSTNAME in client hello.
660  * 2. Server found a match SSL_CTX and use this SSL_CTX to
661  *    continue the SSL handshake.
662  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
663  */
664 TEST(AsyncSSLSocketTest, SNITestMatch) {
665   EventBase eventBase;
666   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
667   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
668   // Use the same SSLContext to continue the handshake after
669   // tlsext_hostname match.
670   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
671   const std::string serverName("xyz.newdev.facebook.com");
672   int fds[2];
673   getfds(fds);
674   getctx(clientCtx, dfServerCtx);
675
676   AsyncSSLSocket::UniquePtr clientSock(
677     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
678   AsyncSSLSocket::UniquePtr serverSock(
679     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
680   SNIClient client(std::move(clientSock));
681   SNIServer server(std::move(serverSock),
682                    dfServerCtx,
683                    hskServerCtx,
684                    serverName);
685
686   eventBase.loop();
687
688   EXPECT_TRUE(client.serverNameMatch);
689   EXPECT_TRUE(server.serverNameMatch);
690 }
691
692 /**
693  * 1. Client sends TLSEXT_HOSTNAME in client hello.
694  * 2. Server cannot find a matching SSL_CTX and continue to use
695  *    the current SSL_CTX to do the handshake.
696  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
697  */
698 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
699   EventBase eventBase;
700   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
701   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
702   // Use the same SSLContext to continue the handshake after
703   // tlsext_hostname match.
704   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
705   const std::string clientRequestingServerName("foo.com");
706   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
707
708   int fds[2];
709   getfds(fds);
710   getctx(clientCtx, dfServerCtx);
711
712   AsyncSSLSocket::UniquePtr clientSock(
713     new AsyncSSLSocket(clientCtx,
714                         &eventBase,
715                         fds[0],
716                         clientRequestingServerName));
717   AsyncSSLSocket::UniquePtr serverSock(
718     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
719   SNIClient client(std::move(clientSock));
720   SNIServer server(std::move(serverSock),
721                    dfServerCtx,
722                    hskServerCtx,
723                    serverExpectedServerName);
724
725   eventBase.loop();
726
727   EXPECT_TRUE(!client.serverNameMatch);
728   EXPECT_TRUE(!server.serverNameMatch);
729 }
730 /**
731  * 1. Client sends TLSEXT_HOSTNAME in client hello.
732  * 2. We then change the serverName.
733  * 3. We expect that we get 'false' as the result for serNameMatch.
734  */
735
736 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
737    EventBase eventBase;
738   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
739   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
740   // Use the same SSLContext to continue the handshake after
741   // tlsext_hostname match.
742   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
743   const std::string serverName("xyz.newdev.facebook.com");
744   int fds[2];
745   getfds(fds);
746   getctx(clientCtx, dfServerCtx);
747
748   AsyncSSLSocket::UniquePtr clientSock(
749     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
750   //Change the server name
751   std::string newName("new.com");
752   clientSock->setServerName(newName);
753   AsyncSSLSocket::UniquePtr serverSock(
754     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
755   SNIClient client(std::move(clientSock));
756   SNIServer server(std::move(serverSock),
757                    dfServerCtx,
758                    hskServerCtx,
759                    serverName);
760
761   eventBase.loop();
762
763   EXPECT_TRUE(!client.serverNameMatch);
764 }
765
766 /**
767  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
768  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
769  */
770 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
771   EventBase eventBase;
772   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
773   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
774   // Use the same SSLContext to continue the handshake after
775   // tlsext_hostname match.
776   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
777   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
778
779   int fds[2];
780   getfds(fds);
781   getctx(clientCtx, dfServerCtx);
782
783   AsyncSSLSocket::UniquePtr clientSock(
784     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
785   AsyncSSLSocket::UniquePtr serverSock(
786     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
787   SNIClient client(std::move(clientSock));
788   SNIServer server(std::move(serverSock),
789                    dfServerCtx,
790                    hskServerCtx,
791                    serverExpectedServerName);
792
793   eventBase.loop();
794
795   EXPECT_TRUE(!client.serverNameMatch);
796   EXPECT_TRUE(!server.serverNameMatch);
797 }
798
799 #endif
800 /**
801  * Test SSL client socket
802  */
803 TEST(AsyncSSLSocketTest, SSLClientTest) {
804   // Start listening on a local port
805   WriteCallbackBase writeCallback;
806   ReadCallback readCallback(&writeCallback);
807   HandshakeCallback handshakeCallback(&readCallback);
808   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
809   TestSSLServer server(&acceptCallback);
810
811   // Set up SSL client
812   EventBase eventBase;
813   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
814
815   client->connect();
816   EventBaseAborter eba(&eventBase, 3000);
817   eventBase.loop();
818
819   EXPECT_EQ(client->getMiss(), 1);
820   EXPECT_EQ(client->getHit(), 0);
821
822   cerr << "SSLClientTest test completed" << endl;
823 }
824
825
826 /**
827  * Test SSL client socket session re-use
828  */
829 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
830   // Start listening on a local port
831   WriteCallbackBase writeCallback;
832   ReadCallback readCallback(&writeCallback);
833   HandshakeCallback handshakeCallback(&readCallback);
834   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
835   TestSSLServer server(&acceptCallback);
836
837   // Set up SSL client
838   EventBase eventBase;
839   auto client =
840       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
841
842   client->connect();
843   EventBaseAborter eba(&eventBase, 3000);
844   eventBase.loop();
845
846   EXPECT_EQ(client->getMiss(), 1);
847   EXPECT_EQ(client->getHit(), 9);
848
849   cerr << "SSLClientTestReuse test completed" << endl;
850 }
851
852 /**
853  * Test SSL client socket timeout
854  */
855 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
856   // Start listening on a local port
857   EmptyReadCallback readCallback;
858   HandshakeCallback handshakeCallback(&readCallback,
859                                       HandshakeCallback::EXPECT_ERROR);
860   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
861   TestSSLServer server(&acceptCallback);
862
863   // Set up SSL client
864   EventBase eventBase;
865   auto client =
866       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
867   client->connect(true /* write before connect completes */);
868   EventBaseAborter eba(&eventBase, 3000);
869   eventBase.loop();
870
871   usleep(100000);
872   // This is checking that the connectError callback precedes any queued
873   // writeError callbacks.  This matches AsyncSocket's behavior
874   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
875   EXPECT_EQ(client->getErrors(), 1);
876   EXPECT_EQ(client->getMiss(), 0);
877   EXPECT_EQ(client->getHit(), 0);
878
879   cerr << "SSLClientTimeoutTest test completed" << endl;
880 }
881
882 // The next 3 tests need an FB-only extension, and will fail without it
883 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
884 /**
885  * Test SSL server async cache
886  */
887 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
888   // Start listening on a local port
889   WriteCallbackBase writeCallback;
890   ReadCallback readCallback(&writeCallback);
891   HandshakeCallback handshakeCallback(&readCallback);
892   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
893   TestSSLAsyncCacheServer server(&acceptCallback);
894
895   // Set up SSL client
896   EventBase eventBase;
897   auto client =
898       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
899
900   client->connect();
901   EventBaseAborter eba(&eventBase, 3000);
902   eventBase.loop();
903
904   EXPECT_EQ(server.getAsyncCallbacks(), 18);
905   EXPECT_EQ(server.getAsyncLookups(), 9);
906   EXPECT_EQ(client->getMiss(), 10);
907   EXPECT_EQ(client->getHit(), 0);
908
909   cerr << "SSLServerAsyncCacheTest test completed" << endl;
910 }
911
912 /**
913  * Test SSL server accept timeout with cache path
914  */
915 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
916   // Start listening on a local port
917   WriteCallbackBase writeCallback;
918   ReadCallback readCallback(&writeCallback);
919   HandshakeCallback handshakeCallback(&readCallback);
920   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
921   TestSSLAsyncCacheServer server(&acceptCallback);
922
923   // Set up SSL client
924   EventBase eventBase;
925   // only do a TCP connect
926   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
927   sock->connect(nullptr, server.getAddress());
928
929   EmptyReadCallback clientReadCallback;
930   clientReadCallback.tcpSocket_ = sock;
931   sock->setReadCB(&clientReadCallback);
932
933   EventBaseAborter eba(&eventBase, 3000);
934   eventBase.loop();
935
936   EXPECT_EQ(readCallback.state, STATE_WAITING);
937
938   cerr << "SSLServerTimeoutTest test completed" << endl;
939 }
940
941 /**
942  * Test SSL server accept timeout with cache path
943  */
944 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
945   // Start listening on a local port
946   WriteCallbackBase writeCallback;
947   ReadCallback readCallback(&writeCallback);
948   HandshakeCallback handshakeCallback(&readCallback);
949   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
950   TestSSLAsyncCacheServer server(&acceptCallback);
951
952   // Set up SSL client
953   EventBase eventBase;
954   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
955
956   client->connect();
957   EventBaseAborter eba(&eventBase, 3000);
958   eventBase.loop();
959
960   EXPECT_EQ(server.getAsyncCallbacks(), 1);
961   EXPECT_EQ(server.getAsyncLookups(), 1);
962   EXPECT_EQ(client->getErrors(), 1);
963   EXPECT_EQ(client->getMiss(), 1);
964   EXPECT_EQ(client->getHit(), 0);
965
966   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
967 }
968
969 /**
970  * Test SSL server accept timeout with cache path
971  */
972 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
973   // Start listening on a local port
974   WriteCallbackBase writeCallback;
975   ReadCallback readCallback(&writeCallback);
976   HandshakeCallback handshakeCallback(&readCallback,
977                                       HandshakeCallback::EXPECT_ERROR);
978   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
979   TestSSLAsyncCacheServer server(&acceptCallback, 500);
980
981   // Set up SSL client
982   EventBase eventBase;
983   auto client =
984       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
985
986   client->connect();
987   EventBaseAborter eba(&eventBase, 3000);
988   eventBase.loop();
989
990   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
991       handshakeCallback.closeSocket();});
992   // give time for the cache lookup to come back and find it closed
993   handshakeCallback.waitForHandshake();
994
995   EXPECT_EQ(server.getAsyncCallbacks(), 1);
996   EXPECT_EQ(server.getAsyncLookups(), 1);
997   EXPECT_EQ(client->getErrors(), 1);
998   EXPECT_EQ(client->getMiss(), 1);
999   EXPECT_EQ(client->getHit(), 0);
1000
1001   cerr << "SSLServerCacheCloseTest test completed" << endl;
1002 }
1003 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1004
1005 /**
1006  * Verify Client Ciphers obtained using SSL MSG Callback.
1007  */
1008 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1009   EventBase eventBase;
1010   auto clientCtx = std::make_shared<SSLContext>();
1011   auto serverCtx = std::make_shared<SSLContext>();
1012   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1013   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1014   serverCtx->loadPrivateKey(kTestKey);
1015   serverCtx->loadCertificate(kTestCert);
1016   serverCtx->loadTrustedCertificates(kTestCA);
1017   serverCtx->loadClientCAList(kTestCA);
1018
1019   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1020   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1021   clientCtx->loadPrivateKey(kTestKey);
1022   clientCtx->loadCertificate(kTestCert);
1023   clientCtx->loadTrustedCertificates(kTestCA);
1024
1025   int fds[2];
1026   getfds(fds);
1027
1028   AsyncSSLSocket::UniquePtr clientSock(
1029       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1030   AsyncSSLSocket::UniquePtr serverSock(
1031       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1032
1033   SSLHandshakeClient client(std::move(clientSock), true, true);
1034   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1035
1036   eventBase.loop();
1037
1038 #if defined(OPENSSL_IS_BORINGSSL)
1039   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1040 #else
1041   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1042 #endif
1043   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1044   EXPECT_TRUE(client.handshakeVerify_);
1045   EXPECT_TRUE(client.handshakeSuccess_);
1046   EXPECT_TRUE(!client.handshakeError_);
1047   EXPECT_TRUE(server.handshakeVerify_);
1048   EXPECT_TRUE(server.handshakeSuccess_);
1049   EXPECT_TRUE(!server.handshakeError_);
1050 }
1051
1052 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1053   EventBase eventBase;
1054   auto ctx = std::make_shared<SSLContext>();
1055
1056   int fds[2];
1057   getfds(fds);
1058
1059   int bufLen = 42;
1060   uint8_t majorVersion = 18;
1061   uint8_t minorVersion = 25;
1062
1063   // Create callback buf
1064   auto buf = IOBuf::create(bufLen);
1065   buf->append(bufLen);
1066   folly::io::RWPrivateCursor cursor(buf.get());
1067   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1068   cursor.write<uint16_t>(0);
1069   cursor.write<uint8_t>(38);
1070   cursor.write<uint8_t>(majorVersion);
1071   cursor.write<uint8_t>(minorVersion);
1072   cursor.skip(32);
1073   cursor.write<uint32_t>(0);
1074
1075   SSL* ssl = ctx->createSSL();
1076   SCOPE_EXIT { SSL_free(ssl); };
1077   AsyncSSLSocket::UniquePtr sock(
1078       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1079   sock->enableClientHelloParsing();
1080
1081   // Test client hello parsing in one packet
1082   AsyncSSLSocket::clientHelloParsingCallback(
1083       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1084   buf.reset();
1085
1086   auto parsedClientHello = sock->getClientHelloInfo();
1087   EXPECT_TRUE(parsedClientHello != nullptr);
1088   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1089   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1090 }
1091
1092 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1093   EventBase eventBase;
1094   auto ctx = std::make_shared<SSLContext>();
1095
1096   int fds[2];
1097   getfds(fds);
1098
1099   int bufLen = 42;
1100   uint8_t majorVersion = 18;
1101   uint8_t minorVersion = 25;
1102
1103   // Create callback buf
1104   auto buf = IOBuf::create(bufLen);
1105   buf->append(bufLen);
1106   folly::io::RWPrivateCursor cursor(buf.get());
1107   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1108   cursor.write<uint16_t>(0);
1109   cursor.write<uint8_t>(38);
1110   cursor.write<uint8_t>(majorVersion);
1111   cursor.write<uint8_t>(minorVersion);
1112   cursor.skip(32);
1113   cursor.write<uint32_t>(0);
1114
1115   SSL* ssl = ctx->createSSL();
1116   SCOPE_EXIT { SSL_free(ssl); };
1117   AsyncSSLSocket::UniquePtr sock(
1118       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1119   sock->enableClientHelloParsing();
1120
1121   // Test parsing with two packets with first packet size < 3
1122   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1123   AsyncSSLSocket::clientHelloParsingCallback(
1124       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1125       ssl, sock.get());
1126   bufCopy.reset();
1127   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1128   AsyncSSLSocket::clientHelloParsingCallback(
1129       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1130       ssl, sock.get());
1131   bufCopy.reset();
1132
1133   auto parsedClientHello = sock->getClientHelloInfo();
1134   EXPECT_TRUE(parsedClientHello != nullptr);
1135   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1136   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1137 }
1138
1139 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1140   EventBase eventBase;
1141   auto ctx = std::make_shared<SSLContext>();
1142
1143   int fds[2];
1144   getfds(fds);
1145
1146   int bufLen = 42;
1147   uint8_t majorVersion = 18;
1148   uint8_t minorVersion = 25;
1149
1150   // Create callback buf
1151   auto buf = IOBuf::create(bufLen);
1152   buf->append(bufLen);
1153   folly::io::RWPrivateCursor cursor(buf.get());
1154   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1155   cursor.write<uint16_t>(0);
1156   cursor.write<uint8_t>(38);
1157   cursor.write<uint8_t>(majorVersion);
1158   cursor.write<uint8_t>(minorVersion);
1159   cursor.skip(32);
1160   cursor.write<uint32_t>(0);
1161
1162   SSL* ssl = ctx->createSSL();
1163   SCOPE_EXIT { SSL_free(ssl); };
1164   AsyncSSLSocket::UniquePtr sock(
1165       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1166   sock->enableClientHelloParsing();
1167
1168   // Test parsing with multiple small packets
1169   for (uint64_t i = 0; i < buf->length(); i += 3) {
1170     auto bufCopy = folly::IOBuf::copyBuffer(
1171         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1172     AsyncSSLSocket::clientHelloParsingCallback(
1173         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1174         ssl, sock.get());
1175     bufCopy.reset();
1176   }
1177
1178   auto parsedClientHello = sock->getClientHelloInfo();
1179   EXPECT_TRUE(parsedClientHello != nullptr);
1180   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1181   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1182 }
1183
1184 /**
1185  * Verify sucessful behavior of SSL certificate validation.
1186  */
1187 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1188   EventBase eventBase;
1189   auto clientCtx = std::make_shared<SSLContext>();
1190   auto dfServerCtx = std::make_shared<SSLContext>();
1191
1192   int fds[2];
1193   getfds(fds);
1194   getctx(clientCtx, dfServerCtx);
1195
1196   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1197   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1198
1199   AsyncSSLSocket::UniquePtr clientSock(
1200     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1201   AsyncSSLSocket::UniquePtr serverSock(
1202     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1203
1204   SSLHandshakeClient client(std::move(clientSock), true, true);
1205   clientCtx->loadTrustedCertificates(kTestCA);
1206
1207   SSLHandshakeServer server(std::move(serverSock), true, true);
1208
1209   eventBase.loop();
1210
1211   EXPECT_TRUE(client.handshakeVerify_);
1212   EXPECT_TRUE(client.handshakeSuccess_);
1213   EXPECT_TRUE(!client.handshakeError_);
1214   EXPECT_LE(0, client.handshakeTime.count());
1215   EXPECT_TRUE(!server.handshakeVerify_);
1216   EXPECT_TRUE(server.handshakeSuccess_);
1217   EXPECT_TRUE(!server.handshakeError_);
1218   EXPECT_LE(0, server.handshakeTime.count());
1219 }
1220
1221 /**
1222  * Verify that the client's verification callback is able to fail SSL
1223  * connection establishment.
1224  */
1225 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1226   EventBase eventBase;
1227   auto clientCtx = std::make_shared<SSLContext>();
1228   auto dfServerCtx = std::make_shared<SSLContext>();
1229
1230   int fds[2];
1231   getfds(fds);
1232   getctx(clientCtx, dfServerCtx);
1233
1234   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1235   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1236
1237   AsyncSSLSocket::UniquePtr clientSock(
1238     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1239   AsyncSSLSocket::UniquePtr serverSock(
1240     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1241
1242   SSLHandshakeClient client(std::move(clientSock), true, false);
1243   clientCtx->loadTrustedCertificates(kTestCA);
1244
1245   SSLHandshakeServer server(std::move(serverSock), true, true);
1246
1247   eventBase.loop();
1248
1249   EXPECT_TRUE(client.handshakeVerify_);
1250   EXPECT_TRUE(!client.handshakeSuccess_);
1251   EXPECT_TRUE(client.handshakeError_);
1252   EXPECT_LE(0, client.handshakeTime.count());
1253   EXPECT_TRUE(!server.handshakeVerify_);
1254   EXPECT_TRUE(!server.handshakeSuccess_);
1255   EXPECT_TRUE(server.handshakeError_);
1256   EXPECT_LE(0, server.handshakeTime.count());
1257 }
1258
1259 /**
1260  * Verify that the options in SSLContext can be overridden in
1261  * sslConnect/Accept.i.e specifying that no validation should be performed
1262  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1263  * the validation callback.
1264  */
1265 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1266   EventBase eventBase;
1267   auto clientCtx = std::make_shared<SSLContext>();
1268   auto dfServerCtx = std::make_shared<SSLContext>();
1269
1270   int fds[2];
1271   getfds(fds);
1272   getctx(clientCtx, dfServerCtx);
1273
1274   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1275   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1276
1277   AsyncSSLSocket::UniquePtr clientSock(
1278     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1279   AsyncSSLSocket::UniquePtr serverSock(
1280     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1281
1282   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1283   clientCtx->loadTrustedCertificates(kTestCA);
1284
1285   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1286
1287   eventBase.loop();
1288
1289   EXPECT_TRUE(!client.handshakeVerify_);
1290   EXPECT_TRUE(client.handshakeSuccess_);
1291   EXPECT_TRUE(!client.handshakeError_);
1292   EXPECT_LE(0, client.handshakeTime.count());
1293   EXPECT_TRUE(!server.handshakeVerify_);
1294   EXPECT_TRUE(server.handshakeSuccess_);
1295   EXPECT_TRUE(!server.handshakeError_);
1296   EXPECT_LE(0, server.handshakeTime.count());
1297 }
1298
1299 /**
1300  * Verify that the options in SSLContext can be overridden in
1301  * sslConnect/Accept. Enable verification even if context says otherwise.
1302  * Test requireClientCert with client cert
1303  */
1304 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1305   EventBase eventBase;
1306   auto clientCtx = std::make_shared<SSLContext>();
1307   auto serverCtx = std::make_shared<SSLContext>();
1308   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1309   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1310   serverCtx->loadPrivateKey(kTestKey);
1311   serverCtx->loadCertificate(kTestCert);
1312   serverCtx->loadTrustedCertificates(kTestCA);
1313   serverCtx->loadClientCAList(kTestCA);
1314
1315   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1316   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1317   clientCtx->loadPrivateKey(kTestKey);
1318   clientCtx->loadCertificate(kTestCert);
1319   clientCtx->loadTrustedCertificates(kTestCA);
1320
1321   int fds[2];
1322   getfds(fds);
1323
1324   AsyncSSLSocket::UniquePtr clientSock(
1325       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1326   AsyncSSLSocket::UniquePtr serverSock(
1327       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1328
1329   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1330   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1331
1332   eventBase.loop();
1333
1334   EXPECT_TRUE(client.handshakeVerify_);
1335   EXPECT_TRUE(client.handshakeSuccess_);
1336   EXPECT_FALSE(client.handshakeError_);
1337   EXPECT_LE(0, client.handshakeTime.count());
1338   EXPECT_TRUE(server.handshakeVerify_);
1339   EXPECT_TRUE(server.handshakeSuccess_);
1340   EXPECT_FALSE(server.handshakeError_);
1341   EXPECT_LE(0, server.handshakeTime.count());
1342 }
1343
1344 /**
1345  * Verify that the client's verification callback is able to override
1346  * the preverification failure and allow a successful connection.
1347  */
1348 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1349   EventBase eventBase;
1350   auto clientCtx = std::make_shared<SSLContext>();
1351   auto dfServerCtx = std::make_shared<SSLContext>();
1352
1353   int fds[2];
1354   getfds(fds);
1355   getctx(clientCtx, dfServerCtx);
1356
1357   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1358   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1359
1360   AsyncSSLSocket::UniquePtr clientSock(
1361     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1362   AsyncSSLSocket::UniquePtr serverSock(
1363     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1364
1365   SSLHandshakeClient client(std::move(clientSock), false, true);
1366   SSLHandshakeServer server(std::move(serverSock), true, true);
1367
1368   eventBase.loop();
1369
1370   EXPECT_TRUE(client.handshakeVerify_);
1371   EXPECT_TRUE(client.handshakeSuccess_);
1372   EXPECT_TRUE(!client.handshakeError_);
1373   EXPECT_LE(0, client.handshakeTime.count());
1374   EXPECT_TRUE(!server.handshakeVerify_);
1375   EXPECT_TRUE(server.handshakeSuccess_);
1376   EXPECT_TRUE(!server.handshakeError_);
1377   EXPECT_LE(0, server.handshakeTime.count());
1378 }
1379
1380 /**
1381  * Verify that specifying that no validation should be performed allows an
1382  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1383  * callback.
1384  */
1385 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1386   EventBase eventBase;
1387   auto clientCtx = std::make_shared<SSLContext>();
1388   auto dfServerCtx = std::make_shared<SSLContext>();
1389
1390   int fds[2];
1391   getfds(fds);
1392   getctx(clientCtx, dfServerCtx);
1393
1394   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1395   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1396
1397   AsyncSSLSocket::UniquePtr clientSock(
1398     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1399   AsyncSSLSocket::UniquePtr serverSock(
1400     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1401
1402   SSLHandshakeClient client(std::move(clientSock), false, false);
1403   SSLHandshakeServer server(std::move(serverSock), false, false);
1404
1405   eventBase.loop();
1406
1407   EXPECT_TRUE(!client.handshakeVerify_);
1408   EXPECT_TRUE(client.handshakeSuccess_);
1409   EXPECT_TRUE(!client.handshakeError_);
1410   EXPECT_LE(0, client.handshakeTime.count());
1411   EXPECT_TRUE(!server.handshakeVerify_);
1412   EXPECT_TRUE(server.handshakeSuccess_);
1413   EXPECT_TRUE(!server.handshakeError_);
1414   EXPECT_LE(0, server.handshakeTime.count());
1415 }
1416
1417 /**
1418  * Test requireClientCert with client cert
1419  */
1420 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1421   EventBase eventBase;
1422   auto clientCtx = std::make_shared<SSLContext>();
1423   auto serverCtx = std::make_shared<SSLContext>();
1424   serverCtx->setVerificationOption(
1425       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1426   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1427   serverCtx->loadPrivateKey(kTestKey);
1428   serverCtx->loadCertificate(kTestCert);
1429   serverCtx->loadTrustedCertificates(kTestCA);
1430   serverCtx->loadClientCAList(kTestCA);
1431
1432   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1433   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1434   clientCtx->loadPrivateKey(kTestKey);
1435   clientCtx->loadCertificate(kTestCert);
1436   clientCtx->loadTrustedCertificates(kTestCA);
1437
1438   int fds[2];
1439   getfds(fds);
1440
1441   AsyncSSLSocket::UniquePtr clientSock(
1442       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1443   AsyncSSLSocket::UniquePtr serverSock(
1444       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1445
1446   SSLHandshakeClient client(std::move(clientSock), true, true);
1447   SSLHandshakeServer server(std::move(serverSock), true, true);
1448
1449   eventBase.loop();
1450
1451   EXPECT_TRUE(client.handshakeVerify_);
1452   EXPECT_TRUE(client.handshakeSuccess_);
1453   EXPECT_FALSE(client.handshakeError_);
1454   EXPECT_LE(0, client.handshakeTime.count());
1455   EXPECT_TRUE(server.handshakeVerify_);
1456   EXPECT_TRUE(server.handshakeSuccess_);
1457   EXPECT_FALSE(server.handshakeError_);
1458   EXPECT_LE(0, server.handshakeTime.count());
1459 }
1460
1461
1462 /**
1463  * Test requireClientCert with no client cert
1464  */
1465 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1466   EventBase eventBase;
1467   auto clientCtx = std::make_shared<SSLContext>();
1468   auto serverCtx = std::make_shared<SSLContext>();
1469   serverCtx->setVerificationOption(
1470       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1471   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1472   serverCtx->loadPrivateKey(kTestKey);
1473   serverCtx->loadCertificate(kTestCert);
1474   serverCtx->loadTrustedCertificates(kTestCA);
1475   serverCtx->loadClientCAList(kTestCA);
1476   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1477   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1478
1479   int fds[2];
1480   getfds(fds);
1481
1482   AsyncSSLSocket::UniquePtr clientSock(
1483       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1484   AsyncSSLSocket::UniquePtr serverSock(
1485       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1486
1487   SSLHandshakeClient client(std::move(clientSock), false, false);
1488   SSLHandshakeServer server(std::move(serverSock), false, false);
1489
1490   eventBase.loop();
1491
1492   EXPECT_FALSE(server.handshakeVerify_);
1493   EXPECT_FALSE(server.handshakeSuccess_);
1494   EXPECT_TRUE(server.handshakeError_);
1495   EXPECT_LE(0, client.handshakeTime.count());
1496   EXPECT_LE(0, server.handshakeTime.count());
1497 }
1498
1499 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1500   auto cert = getFileAsBuf(kTestCert);
1501   auto key = getFileAsBuf(kTestKey);
1502
1503   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1504   BIO_write(certBio.get(), cert.data(), cert.size());
1505   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1506   BIO_write(keyBio.get(), key.data(), key.size());
1507
1508   // Create SSL structs from buffers to get properties
1509   ssl::X509UniquePtr certStruct(
1510       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1511   ssl::EvpPkeyUniquePtr keyStruct(
1512       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1513   certBio = nullptr;
1514   keyBio = nullptr;
1515
1516   auto origCommonName = getCommonName(certStruct.get());
1517   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1518   certStruct = nullptr;
1519   keyStruct = nullptr;
1520
1521   auto ctx = std::make_shared<SSLContext>();
1522   ctx->loadPrivateKeyFromBufferPEM(key);
1523   ctx->loadCertificateFromBufferPEM(cert);
1524   ctx->loadTrustedCertificates(kTestCA);
1525
1526   ssl::SSLUniquePtr ssl(ctx->createSSL());
1527
1528   auto newCert = SSL_get_certificate(ssl.get());
1529   auto newKey = SSL_get_privatekey(ssl.get());
1530
1531   // Get properties from SSL struct
1532   auto newCommonName = getCommonName(newCert);
1533   auto newKeySize = EVP_PKEY_bits(newKey);
1534
1535   // Check that the key and cert have the expected properties
1536   EXPECT_EQ(origCommonName, newCommonName);
1537   EXPECT_EQ(origKeySize, newKeySize);
1538 }
1539
1540 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1541   EventBase eb;
1542
1543   // Set up SSL context.
1544   auto sslContext = std::make_shared<SSLContext>();
1545   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1546
1547   // create SSL socket
1548   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1549
1550   EXPECT_EQ(1500, socket->getMinWriteSize());
1551
1552   socket->setMinWriteSize(0);
1553   EXPECT_EQ(0, socket->getMinWriteSize());
1554   socket->setMinWriteSize(50000);
1555   EXPECT_EQ(50000, socket->getMinWriteSize());
1556 }
1557
1558 class ReadCallbackTerminator : public ReadCallback {
1559  public:
1560   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1561       : ReadCallback(wcb)
1562       , base_(base) {}
1563
1564   // Do not write data back, terminate the loop.
1565   void readDataAvailable(size_t len) noexcept override {
1566     std::cerr << "readDataAvailable, len " << len << std::endl;
1567
1568     currentBuffer.length = len;
1569
1570     buffers.push_back(currentBuffer);
1571     currentBuffer.reset();
1572     state = STATE_SUCCEEDED;
1573
1574     socket_->setReadCB(nullptr);
1575     base_->terminateLoopSoon();
1576   }
1577  private:
1578   EventBase* base_;
1579 };
1580
1581
1582 /**
1583  * Test a full unencrypted codepath
1584  */
1585 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1586   EventBase base;
1587
1588   auto clientCtx = std::make_shared<folly::SSLContext>();
1589   auto serverCtx = std::make_shared<folly::SSLContext>();
1590   int fds[2];
1591   getfds(fds);
1592   getctx(clientCtx, serverCtx);
1593   auto client = AsyncSSLSocket::newSocket(
1594                   clientCtx, &base, fds[0], false, true);
1595   auto server = AsyncSSLSocket::newSocket(
1596                   serverCtx, &base, fds[1], true, true);
1597
1598   ReadCallbackTerminator readCallback(&base, nullptr);
1599   server->setReadCB(&readCallback);
1600   readCallback.setSocket(server);
1601
1602   uint8_t buf[128];
1603   memset(buf, 'a', sizeof(buf));
1604   client->write(nullptr, buf, sizeof(buf));
1605
1606   // Check that bytes are unencrypted
1607   char c;
1608   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1609   EXPECT_EQ('a', c);
1610
1611   EventBaseAborter eba(&base, 3000);
1612   base.loop();
1613
1614   EXPECT_EQ(1, readCallback.buffers.size());
1615   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1616
1617   server->setReadCB(&readCallback);
1618
1619   // Unencrypted
1620   server->sslAccept(nullptr);
1621   client->sslConn(nullptr);
1622
1623   // Do NOT wait for handshake, writing should be queued and happen after
1624
1625   client->write(nullptr, buf, sizeof(buf));
1626
1627   // Check that bytes are *not* unencrypted
1628   char c2;
1629   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1630   EXPECT_NE('a', c2);
1631
1632
1633   base.loop();
1634
1635   EXPECT_EQ(2, readCallback.buffers.size());
1636   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1637 }
1638
1639 TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
1640   auto clientCtx = std::make_shared<folly::SSLContext>();
1641   auto serverCtx = std::make_shared<folly::SSLContext>();
1642   getctx(clientCtx, serverCtx);
1643
1644   WriteCallbackBase writeCallback;
1645   ReadCallback readCallback(&writeCallback);
1646   HandshakeCallback handshakeCallback(&readCallback);
1647   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1648   TestSSLServer server(&acceptCallback);
1649
1650   EventBase evb;
1651   std::shared_ptr<AsyncSSLSocket> socket =
1652       AsyncSSLSocket::newSocket(clientCtx, &evb, true);
1653   socket->connect(nullptr, server.getAddress(), 0);
1654
1655   evb.loop();
1656
1657   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
1658   socket->sslConn(nullptr);
1659   evb.loop();
1660   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
1661
1662   // write()
1663   std::array<uint8_t, 128> buf;
1664   memset(buf.data(), 'a', buf.size());
1665   socket->write(nullptr, buf.data(), buf.size());
1666
1667   socket->close();
1668 }
1669
1670 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1671   // Start listening on a local port
1672   WriteCallbackBase writeCallback;
1673   WriteErrorCallback readCallback(&writeCallback);
1674   HandshakeCallback handshakeCallback(&readCallback,
1675                                       HandshakeCallback::EXPECT_ERROR);
1676   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1677   TestSSLServer server(&acceptCallback);
1678
1679   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1680   socket->open();
1681   uint8_t buf[3] = {0x16, 0x03, 0x01};
1682   socket->write(buf, sizeof(buf));
1683   socket->closeWithReset();
1684
1685   handshakeCallback.waitForHandshake();
1686   EXPECT_NE(
1687       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1688   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1689 }
1690
1691 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1692   // Start listening on a local port
1693   WriteCallbackBase writeCallback;
1694   WriteErrorCallback readCallback(&writeCallback);
1695   HandshakeCallback handshakeCallback(&readCallback,
1696                                       HandshakeCallback::EXPECT_ERROR);
1697   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1698   TestSSLServer server(&acceptCallback);
1699
1700   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1701   socket->open();
1702   uint8_t buf[3] = {0x16, 0x03, 0x01};
1703   socket->write(buf, sizeof(buf));
1704   socket->close();
1705
1706   handshakeCallback.waitForHandshake();
1707 #if FOLLY_OPENSSL_IS_110
1708   EXPECT_NE(
1709       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1710 #else
1711   EXPECT_NE(
1712       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1713 #endif
1714 }
1715
1716 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1717   // Start listening on a local port
1718   WriteCallbackBase writeCallback;
1719   WriteErrorCallback readCallback(&writeCallback);
1720   HandshakeCallback handshakeCallback(&readCallback,
1721                                       HandshakeCallback::EXPECT_ERROR);
1722   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1723   TestSSLServer server(&acceptCallback);
1724
1725   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1726   socket->open();
1727   uint8_t buf[256] = {0x16, 0x03};
1728   memset(buf + 2, 'a', sizeof(buf) - 2);
1729   socket->write(buf, sizeof(buf));
1730   socket->close();
1731
1732   handshakeCallback.waitForHandshake();
1733   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1734             std::string::npos);
1735 #if defined(OPENSSL_IS_BORINGSSL)
1736   EXPECT_NE(
1737       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1738       std::string::npos);
1739 #elif FOLLY_OPENSSL_IS_110
1740   EXPECT_NE(handshakeCallback.errorString_.find("packet length too long"),
1741             std::string::npos);
1742 #else
1743   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1744             std::string::npos);
1745 #endif
1746 }
1747
1748 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1749   using folly::ssl::OpenSSLUtils;
1750   EXPECT_EQ(
1751       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1752   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1753   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1754   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1755   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1756 }
1757
1758 #if FOLLY_ALLOW_TFO
1759
1760 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1761  public:
1762   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1763
1764   explicit MockAsyncTFOSSLSocket(
1765       std::shared_ptr<folly::SSLContext> sslCtx,
1766       EventBase* evb)
1767       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1768
1769   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1770 };
1771
1772 /**
1773  * Test connecting to, writing to, reading from, and closing the
1774  * connection to the SSL server with TFO.
1775  */
1776 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1777   // Start listening on a local port
1778   WriteCallbackBase writeCallback;
1779   ReadCallback readCallback(&writeCallback);
1780   HandshakeCallback handshakeCallback(&readCallback);
1781   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1782   TestSSLServer server(&acceptCallback, true);
1783
1784   // Set up SSL context.
1785   auto sslContext = std::make_shared<SSLContext>();
1786
1787   // connect
1788   auto socket =
1789       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1790   socket->enableTFO();
1791   socket->open();
1792
1793   // write()
1794   std::array<uint8_t, 128> buf;
1795   memset(buf.data(), 'a', buf.size());
1796   socket->write(buf.data(), buf.size());
1797
1798   // read()
1799   std::array<uint8_t, 128> readbuf;
1800   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1801   EXPECT_EQ(bytesRead, 128);
1802   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1803
1804   // close()
1805   socket->close();
1806 }
1807
1808 /**
1809  * Test connecting to, writing to, reading from, and closing the
1810  * connection to the SSL server with TFO.
1811  */
1812 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1813   // Start listening on a local port
1814   WriteCallbackBase writeCallback;
1815   ReadCallback readCallback(&writeCallback);
1816   HandshakeCallback handshakeCallback(&readCallback);
1817   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1818   TestSSLServer server(&acceptCallback, false);
1819
1820   // Set up SSL context.
1821   auto sslContext = std::make_shared<SSLContext>();
1822
1823   // connect
1824   auto socket =
1825       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1826   socket->enableTFO();
1827   socket->open();
1828
1829   // write()
1830   std::array<uint8_t, 128> buf;
1831   memset(buf.data(), 'a', buf.size());
1832   socket->write(buf.data(), buf.size());
1833
1834   // read()
1835   std::array<uint8_t, 128> readbuf;
1836   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1837   EXPECT_EQ(bytesRead, 128);
1838   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1839
1840   // close()
1841   socket->close();
1842 }
1843
1844 class ConnCallback : public AsyncSocket::ConnectCallback {
1845  public:
1846   void connectSuccess() noexcept override {
1847     state = State::SUCCESS;
1848   }
1849
1850   void connectErr(const AsyncSocketException& ex) noexcept override {
1851     state = State::ERROR;
1852     error = ex.what();
1853   }
1854
1855   enum class State { WAITING, SUCCESS, ERROR };
1856
1857   State state{State::WAITING};
1858   std::string error;
1859 };
1860
1861 template <class Cardinality>
1862 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1863     EventBase* evb,
1864     const SocketAddress& address,
1865     Cardinality cardinality) {
1866   // Set up SSL context.
1867   auto sslContext = std::make_shared<SSLContext>();
1868
1869   // connect
1870   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1871       new MockAsyncTFOSSLSocket(sslContext, evb));
1872   socket->enableTFO();
1873
1874   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1875       .Times(cardinality)
1876       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1877         sockaddr_storage addr;
1878         auto len = address.getAddress(&addr);
1879         return connect(fd, (const struct sockaddr*)&addr, len);
1880       }));
1881   return socket;
1882 }
1883
1884 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1885   // Start listening on a local port
1886   WriteCallbackBase writeCallback;
1887   ReadCallback readCallback(&writeCallback);
1888   HandshakeCallback handshakeCallback(&readCallback);
1889   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1890   TestSSLServer server(&acceptCallback, true);
1891
1892   EventBase evb;
1893
1894   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1895   ConnCallback ccb;
1896   socket->connect(&ccb, server.getAddress(), 30);
1897
1898   evb.loop();
1899   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1900
1901   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1902   evb.loop();
1903
1904   BlockingSocket sock(std::move(socket));
1905   // write()
1906   std::array<uint8_t, 128> buf;
1907   memset(buf.data(), 'a', buf.size());
1908   sock.write(buf.data(), buf.size());
1909
1910   // read()
1911   std::array<uint8_t, 128> readbuf;
1912   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1913   EXPECT_EQ(bytesRead, 128);
1914   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1915
1916   // close()
1917   sock.close();
1918 }
1919
1920 #if !defined(OPENSSL_IS_BORINGSSL)
1921 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1922   // Start listening on a local port
1923   ConnectTimeoutCallback acceptCallback;
1924   TestSSLServer server(&acceptCallback, true);
1925
1926   // Set up SSL context.
1927   auto sslContext = std::make_shared<SSLContext>();
1928
1929   // connect
1930   auto socket =
1931       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1932   socket->enableTFO();
1933   EXPECT_THROW(
1934       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1935 }
1936 #endif
1937
1938 #if !defined(OPENSSL_IS_BORINGSSL)
1939 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1940   // Start listening on a local port
1941   ConnectTimeoutCallback acceptCallback;
1942   TestSSLServer server(&acceptCallback, true);
1943
1944   EventBase evb;
1945
1946   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1947   ConnCallback ccb;
1948   // Set a short timeout
1949   socket->connect(&ccb, server.getAddress(), 1);
1950
1951   evb.loop();
1952   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1953 }
1954 #endif
1955
1956 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1957   // Start listening on a local port
1958   EmptyReadCallback readCallback;
1959   HandshakeCallback handshakeCallback(
1960       &readCallback, HandshakeCallback::EXPECT_ERROR);
1961   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1962   TestSSLServer server(&acceptCallback, true);
1963
1964   EventBase evb;
1965
1966   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1967   ConnCallback ccb;
1968   socket->connect(&ccb, server.getAddress(), 100);
1969
1970   evb.loop();
1971   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1972   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1973 }
1974
1975 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1976   // Start listening on a local port
1977   EventBase evb;
1978
1979   // Hopefully nothing is listening on this address
1980   SocketAddress addr("127.0.0.1", 65535);
1981   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1982   ConnCallback ccb;
1983   socket->connect(&ccb, addr, 100);
1984
1985   evb.loop();
1986   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1987   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1988 }
1989
1990 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1991   EventBase clientEventBase;
1992   EventBase serverEventBase;
1993   auto clientCtx = std::make_shared<SSLContext>();
1994   auto dfServerCtx = std::make_shared<SSLContext>();
1995   std::array<int, 2> fds;
1996   getfds(fds.data());
1997   getctx(clientCtx, dfServerCtx);
1998
1999   AsyncSSLSocket::UniquePtr clientSockPtr(
2000       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2001   AsyncSSLSocket::UniquePtr serverSockPtr(
2002       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
2003   auto clientSock = clientSockPtr.get();
2004   auto serverSock = serverSockPtr.get();
2005   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2006
2007   // Steal some data from the server.
2008   clientEventBase.loopOnce();
2009   std::array<uint8_t, 10> buf;
2010   recv(fds[1], buf.data(), buf.size(), 0);
2011
2012   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2013   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
2014   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2015     serverEventBase.loopOnce();
2016     clientEventBase.loopOnce();
2017   }
2018
2019   EXPECT_TRUE(client.handshakeSuccess_);
2020   EXPECT_TRUE(server.handshakeSuccess_);
2021   EXPECT_EQ(
2022       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2023 }
2024
2025 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
2026   EventBase clientEventBase;
2027   EventBase serverEventBase;
2028   auto clientCtx = std::make_shared<SSLContext>();
2029   auto dfServerCtx = std::make_shared<SSLContext>();
2030   std::array<int, 2> fds;
2031   getfds(fds.data());
2032   getctx(clientCtx, dfServerCtx);
2033
2034   AsyncSSLSocket::UniquePtr clientSockPtr(
2035       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
2036   AsyncSocket::UniquePtr serverSockPtr(
2037       new AsyncSocket(&serverEventBase, fds[1]));
2038   auto clientSock = clientSockPtr.get();
2039   auto serverSock = serverSockPtr.get();
2040   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
2041
2042   // Steal some data from the server.
2043   clientEventBase.loopOnce();
2044   std::array<uint8_t, 10> buf;
2045   recv(fds[1], buf.data(), buf.size(), 0);
2046   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
2047   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
2048       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
2049   auto serverSSLSock = serverSSLSockPtr.get();
2050   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
2051   while (!client.handshakeSuccess_ && !client.handshakeError_) {
2052     serverEventBase.loopOnce();
2053     clientEventBase.loopOnce();
2054   }
2055
2056   EXPECT_TRUE(client.handshakeSuccess_);
2057   EXPECT_TRUE(server.handshakeSuccess_);
2058   EXPECT_EQ(
2059       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
2060 }
2061
2062 /**
2063  * Test overriding the flags passed to "sendmsg()" system call,
2064  * and verifying that write requests fail properly.
2065  */
2066 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2067   // Start listening on a local port
2068   SendMsgFlagsCallback msgCallback;
2069   ExpectWriteErrorCallback writeCallback(&msgCallback);
2070   ReadCallback readCallback(&writeCallback);
2071   HandshakeCallback handshakeCallback(&readCallback);
2072   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2073   TestSSLServer server(&acceptCallback);
2074
2075   // Set up SSL context.
2076   auto sslContext = std::make_shared<SSLContext>();
2077   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2078
2079   // connect
2080   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2081                                                  sslContext);
2082   socket->open();
2083
2084   // Setting flags to "-1" to trigger "Invalid argument" error
2085   // on attempt to use this flags in sendmsg() system call.
2086   msgCallback.resetFlags(-1);
2087
2088   // write()
2089   std::vector<uint8_t> buf(128, 'a');
2090   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2091
2092   // close()
2093   socket->close();
2094
2095   cerr << "SendMsgParamsCallback test completed" << endl;
2096 }
2097
2098 #ifdef MSG_ERRQUEUE
2099 /**
2100  * Test connecting to, writing to, reading from, and closing the
2101  * connection to the SSL server.
2102  */
2103 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2104   // This test requires Linux kernel v4.6 or later
2105   struct utsname s_uname;
2106   memset(&s_uname, 0, sizeof(s_uname));
2107   ASSERT_EQ(uname(&s_uname), 0);
2108   int major, minor;
2109   folly::StringPiece extra;
2110   if (folly::split<false>(
2111         '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2112     if (major < 4 || (major == 4 && minor < 6)) {
2113       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2114                 << "kernel ver. " << s_uname.release << " detected).";
2115       return;
2116     }
2117   }
2118
2119   // Start listening on a local port
2120   SendMsgDataCallback msgCallback;
2121   WriteCheckTimestampCallback writeCallback(&msgCallback);
2122   ReadCallback readCallback(&writeCallback);
2123   HandshakeCallback handshakeCallback(&readCallback);
2124   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2125   TestSSLServer server(&acceptCallback);
2126
2127   // Set up SSL context.
2128   auto sslContext = std::make_shared<SSLContext>();
2129   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2130
2131   // connect
2132   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2133                                                  sslContext);
2134   socket->open();
2135
2136   // Adding MSG_EOR flag to the message flags - it'll trigger
2137   // timestamp generation for the last byte of the message.
2138   msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
2139
2140   // Init ancillary data buffer to trigger timestamp notification
2141   union {
2142     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2143     struct cmsghdr cmsg;
2144   } u;
2145   u.cmsg.cmsg_level = SOL_SOCKET;
2146   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2147   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2148   uint32_t flags =
2149       SOF_TIMESTAMPING_TX_SCHED |
2150       SOF_TIMESTAMPING_TX_SOFTWARE |
2151       SOF_TIMESTAMPING_TX_ACK;
2152   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2153   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2154   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2155   msgCallback.resetData(std::move(ctrl));
2156
2157   // write()
2158   std::vector<uint8_t> buf(128, 'a');
2159   socket->write(buf.data(), buf.size());
2160
2161   // read()
2162   std::vector<uint8_t> readbuf(buf.size());
2163   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2164   EXPECT_EQ(bytesRead, buf.size());
2165   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2166
2167   writeCallback.checkForTimestampNotifications();
2168
2169   // close()
2170   socket->close();
2171
2172   cerr << "SendMsgDataCallback test completed" << endl;
2173 }
2174 #endif // MSG_ERRQUEUE
2175
2176 #endif
2177
2178 } // namespace
2179
2180 #ifdef SIGPIPE
2181 ///////////////////////////////////////////////////////////////////////////
2182 // init_unit_test_suite
2183 ///////////////////////////////////////////////////////////////////////////
2184 namespace {
2185 struct Initializer {
2186   Initializer() {
2187     signal(SIGPIPE, SIG_IGN);
2188   }
2189 };
2190 Initializer initializer;
2191 } // anonymous
2192 #endif