2017
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/GMock.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/OpenSSL.h>
27 #include <folly/portability/Sockets.h>
28 #include <folly/portability/Unistd.h>
29
30 #include <folly/io/async/test/BlockingSocket.h>
31
32 #include <fcntl.h>
33 #include <folly/io/Cursor.h>
34 #include <openssl/bio.h>
35 #include <sys/types.h>
36 #include <fstream>
37 #include <iostream>
38 #include <list>
39 #include <set>
40 #include <thread>
41
42 using std::string;
43 using std::vector;
44 using std::min;
45 using std::cerr;
46 using std::endl;
47 using std::list;
48
49 using namespace testing;
50
51 namespace folly {
52 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
53 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
54 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
55
56 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
57 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
58 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
59
60 constexpr size_t SSLClient::kMaxReadBufferSz;
61 constexpr size_t SSLClient::kMaxReadsPerEvent;
62
63 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
64     : ctx_(new folly::SSLContext),
65       acb_(acb),
66       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
67   // Set up the SSL context
68   ctx_->loadCertificate(testCert);
69   ctx_->loadPrivateKey(testKey);
70   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
71
72   acb_->ctx_ = ctx_;
73   acb_->base_ = &evb_;
74
75   // Enable TFO
76   if (enableTFO) {
77     LOG(INFO) << "server TFO enabled";
78     socket_->setTFOEnabled(true, 1000);
79   }
80
81   // set up the listening socket
82   socket_->bind(0);
83   socket_->getAddress(&address_);
84   socket_->listen(100);
85   socket_->addAcceptCallback(acb_, &evb_);
86   socket_->startAccepting();
87
88   int ret = pthread_create(&thread_, nullptr, Main, this);
89   assert(ret == 0);
90   (void)ret;
91
92   std::cerr << "Accepting connections on " << address_ << std::endl;
93 }
94
95 void getfds(int fds[2]) {
96   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
97     FAIL() << "failed to create socketpair: " << strerror(errno);
98   }
99   for (int idx = 0; idx < 2; ++idx) {
100     int flags = fcntl(fds[idx], F_GETFL, 0);
101     if (flags == -1) {
102       FAIL() << "failed to get flags for socket " << idx << ": "
103              << strerror(errno);
104     }
105     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
106       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
107              << strerror(errno);
108     }
109   }
110 }
111
112 void getctx(
113   std::shared_ptr<folly::SSLContext> clientCtx,
114   std::shared_ptr<folly::SSLContext> serverCtx) {
115   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
116
117   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
118   serverCtx->loadCertificate(
119       testCert);
120   serverCtx->loadPrivateKey(
121       testKey);
122 }
123
124 void sslsocketpair(
125   EventBase* eventBase,
126   AsyncSSLSocket::UniquePtr* clientSock,
127   AsyncSSLSocket::UniquePtr* serverSock) {
128   auto clientCtx = std::make_shared<folly::SSLContext>();
129   auto serverCtx = std::make_shared<folly::SSLContext>();
130   int fds[2];
131   getfds(fds);
132   getctx(clientCtx, serverCtx);
133   clientSock->reset(new AsyncSSLSocket(
134                       clientCtx, eventBase, fds[0], false));
135   serverSock->reset(new AsyncSSLSocket(
136                       serverCtx, eventBase, fds[1], true));
137
138   // (*clientSock)->setSendTimeout(100);
139   // (*serverSock)->setSendTimeout(100);
140 }
141
142 // client protocol filters
143 bool clientProtoFilterPickPony(unsigned char** client,
144   unsigned int* client_len, const unsigned char*, unsigned int ) {
145   //the protocol string in length prefixed byte string. the
146   //length byte is not included in the length
147   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
148   *client = p;
149   *client_len = 7;
150   return true;
151 }
152
153 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
154   const unsigned char*, unsigned int) {
155   return false;
156 }
157
158 std::string getFileAsBuf(const char* fileName) {
159   std::string buffer;
160   folly::readFile(fileName, buffer);
161   return buffer;
162 }
163
164 std::string getCommonName(X509* cert) {
165   X509_NAME* subject = X509_get_subject_name(cert);
166   std::string cn;
167   cn.resize(ub_common_name);
168   X509_NAME_get_text_by_NID(
169       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
170   return cn;
171 }
172
173 /**
174  * Test connecting to, writing to, reading from, and closing the
175  * connection to the SSL server.
176  */
177 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
183   TestSSLServer server(&acceptCallback);
184
185   // Set up SSL context.
186   std::shared_ptr<SSLContext> sslContext(new SSLContext());
187   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
188   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
189   //sslContext->authenticate(true, false);
190
191   // connect
192   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
193                                                  sslContext);
194   socket->open();
195
196   // write()
197   uint8_t buf[128];
198   memset(buf, 'a', sizeof(buf));
199   socket->write(buf, sizeof(buf));
200
201   // read()
202   uint8_t readbuf[128];
203   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
204   EXPECT_EQ(bytesRead, 128);
205   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
206
207   // close()
208   socket->close();
209
210   cerr << "ConnectWriteReadClose test completed" << endl;
211 }
212
213 /**
214  * Test reading after server close.
215  */
216 TEST(AsyncSSLSocketTest, ReadAfterClose) {
217   // Start listening on a local port
218   WriteCallbackBase writeCallback;
219   ReadEOFCallback readCallback(&writeCallback);
220   HandshakeCallback handshakeCallback(&readCallback);
221   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
222   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
223
224   // Set up SSL context.
225   auto sslContext = std::make_shared<SSLContext>();
226   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
227
228   auto socket =
229       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
230   socket->open();
231
232   // This should trigger an EOF on the client.
233   auto evb = handshakeCallback.getSocket()->getEventBase();
234   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
235   std::array<uint8_t, 128> readbuf;
236   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
237   EXPECT_EQ(0, bytesRead);
238 }
239
240 /**
241  * Test bad renegotiation
242  */
243 #if !defined(OPENSSL_IS_BORINGSSL)
244 TEST(AsyncSSLSocketTest, Renegotiate) {
245   EventBase eventBase;
246   auto clientCtx = std::make_shared<SSLContext>();
247   auto dfServerCtx = std::make_shared<SSLContext>();
248   std::array<int, 2> fds;
249   getfds(fds.data());
250   getctx(clientCtx, dfServerCtx);
251
252   AsyncSSLSocket::UniquePtr clientSock(
253       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
254   AsyncSSLSocket::UniquePtr serverSock(
255       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
256   SSLHandshakeClient client(std::move(clientSock), true, true);
257   RenegotiatingServer server(std::move(serverSock));
258
259   while (!client.handshakeSuccess_ && !client.handshakeError_) {
260     eventBase.loopOnce();
261   }
262
263   ASSERT_TRUE(client.handshakeSuccess_);
264
265   auto sslSock = std::move(client).moveSocket();
266   sslSock->detachEventBase();
267   // This is nasty, however we don't want to add support for
268   // renegotiation in AsyncSSLSocket.
269   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
270
271   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
272
273   std::thread t([&]() { eventBase.loopForever(); });
274
275   // Trigger the renegotiation.
276   std::array<uint8_t, 128> buf;
277   memset(buf.data(), 'a', buf.size());
278   try {
279     socket->write(buf.data(), buf.size());
280   } catch (AsyncSocketException& e) {
281     LOG(INFO) << "client got error " << e.what();
282   }
283   eventBase.terminateLoopSoon();
284   t.join();
285
286   eventBase.loop();
287   ASSERT_TRUE(server.renegotiationError_);
288 }
289 #endif
290
291 /**
292  * Negative test for handshakeError().
293  */
294 TEST(AsyncSSLSocketTest, HandshakeError) {
295   // Start listening on a local port
296   WriteCallbackBase writeCallback;
297   WriteErrorCallback readCallback(&writeCallback);
298   HandshakeCallback handshakeCallback(&readCallback);
299   HandshakeErrorCallback acceptCallback(&handshakeCallback);
300   TestSSLServer server(&acceptCallback);
301
302   // Set up SSL context.
303   std::shared_ptr<SSLContext> sslContext(new SSLContext());
304   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
305
306   // connect
307   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
308                                                  sslContext);
309   // read()
310   bool ex = false;
311   try {
312     socket->open();
313
314     uint8_t readbuf[128];
315     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
316     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
317   } catch (AsyncSocketException&) {
318     ex = true;
319   }
320   EXPECT_TRUE(ex);
321
322   // close()
323   socket->close();
324   cerr << "HandshakeError test completed" << endl;
325 }
326
327 /**
328  * Negative test for readError().
329  */
330 TEST(AsyncSSLSocketTest, ReadError) {
331   // Start listening on a local port
332   WriteCallbackBase writeCallback;
333   ReadErrorCallback readCallback(&writeCallback);
334   HandshakeCallback handshakeCallback(&readCallback);
335   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
336   TestSSLServer server(&acceptCallback);
337
338   // Set up SSL context.
339   std::shared_ptr<SSLContext> sslContext(new SSLContext());
340   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
341
342   // connect
343   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
344                                                  sslContext);
345   socket->open();
346
347   // write something to trigger ssl handshake
348   uint8_t buf[128];
349   memset(buf, 'a', sizeof(buf));
350   socket->write(buf, sizeof(buf));
351
352   socket->close();
353   cerr << "ReadError test completed" << endl;
354 }
355
356 /**
357  * Negative test for writeError().
358  */
359 TEST(AsyncSSLSocketTest, WriteError) {
360   // Start listening on a local port
361   WriteCallbackBase writeCallback;
362   WriteErrorCallback readCallback(&writeCallback);
363   HandshakeCallback handshakeCallback(&readCallback);
364   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
365   TestSSLServer server(&acceptCallback);
366
367   // Set up SSL context.
368   std::shared_ptr<SSLContext> sslContext(new SSLContext());
369   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
370
371   // connect
372   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
373                                                  sslContext);
374   socket->open();
375
376   // write something to trigger ssl handshake
377   uint8_t buf[128];
378   memset(buf, 'a', sizeof(buf));
379   socket->write(buf, sizeof(buf));
380
381   socket->close();
382   cerr << "WriteError test completed" << endl;
383 }
384
385 /**
386  * Test a socket with TCP_NODELAY unset.
387  */
388 TEST(AsyncSSLSocketTest, SocketWithDelay) {
389   // Start listening on a local port
390   WriteCallbackBase writeCallback;
391   ReadCallback readCallback(&writeCallback);
392   HandshakeCallback handshakeCallback(&readCallback);
393   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
394   TestSSLServer server(&acceptCallback);
395
396   // Set up SSL context.
397   std::shared_ptr<SSLContext> sslContext(new SSLContext());
398   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
399
400   // connect
401   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
402                                                  sslContext);
403   socket->open();
404
405   // write()
406   uint8_t buf[128];
407   memset(buf, 'a', sizeof(buf));
408   socket->write(buf, sizeof(buf));
409
410   // read()
411   uint8_t readbuf[128];
412   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
413   EXPECT_EQ(bytesRead, 128);
414   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
415
416   // close()
417   socket->close();
418
419   cerr << "SocketWithDelay test completed" << endl;
420 }
421
422 using NextProtocolTypePair =
423     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
424
425 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
426   // For matching protos
427  public:
428   void SetUp() override { getctx(clientCtx, serverCtx); }
429
430   void connect(bool unset = false) {
431     getfds(fds);
432
433     if (unset) {
434       // unsetting NPN for any of [client, server] is enough to make NPN not
435       // work
436       clientCtx->unsetNextProtocols();
437     }
438
439     AsyncSSLSocket::UniquePtr clientSock(
440       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
441     AsyncSSLSocket::UniquePtr serverSock(
442       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
443     client = folly::make_unique<NpnClient>(std::move(clientSock));
444     server = folly::make_unique<NpnServer>(std::move(serverSock));
445
446     eventBase.loop();
447   }
448
449   void expectProtocol(const std::string& proto) {
450     EXPECT_NE(client->nextProtoLength, 0);
451     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
452     EXPECT_EQ(
453         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
454         0);
455     string selected((const char*)client->nextProto, client->nextProtoLength);
456     EXPECT_EQ(proto, selected);
457   }
458
459   void expectNoProtocol() {
460     EXPECT_EQ(client->nextProtoLength, 0);
461     EXPECT_EQ(server->nextProtoLength, 0);
462     EXPECT_EQ(client->nextProto, nullptr);
463     EXPECT_EQ(server->nextProto, nullptr);
464   }
465
466   void expectProtocolType() {
467     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
468         GetParam().second == SSLContext::NextProtocolType::ANY) {
469       EXPECT_EQ(client->protocolType, server->protocolType);
470     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
471                GetParam().second == SSLContext::NextProtocolType::ANY) {
472       // Well not much we can say
473     } else {
474       expectProtocolType(GetParam());
475     }
476   }
477
478   void expectProtocolType(NextProtocolTypePair expected) {
479     EXPECT_EQ(client->protocolType, expected.first);
480     EXPECT_EQ(server->protocolType, expected.second);
481   }
482
483   EventBase eventBase;
484   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
485   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
486   int fds[2];
487   std::unique_ptr<NpnClient> client;
488   std::unique_ptr<NpnServer> server;
489 };
490
491 class NextProtocolTLSExtTest : public NextProtocolTest {
492   // For extended TLS protos
493 };
494
495 class NextProtocolNPNOnlyTest : public NextProtocolTest {
496   // For mismatching protos
497 };
498
499 class NextProtocolMismatchTest : public NextProtocolTest {
500   // For mismatching protos
501 };
502
503 TEST_P(NextProtocolTest, NpnTestOverlap) {
504   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
505   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
506                                         GetParam().second);
507
508   connect();
509
510   expectProtocol("baz");
511   expectProtocolType();
512 }
513
514 TEST_P(NextProtocolTest, NpnTestUnset) {
515   // Identical to above test, except that we want unset NPN before
516   // looping.
517   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
518   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
519                                         GetParam().second);
520
521   connect(true /* unset */);
522
523   // if alpn negotiation fails, type will appear as npn
524   expectNoProtocol();
525   EXPECT_EQ(client->protocolType, server->protocolType);
526 }
527
528 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
529   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
530   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
531                                         GetParam().second);
532
533   connect();
534
535   expectNoProtocol();
536   expectProtocolType(
537       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
538 }
539
540 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
541 // will fail on 1.0.2 before that.
542 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
543   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
544   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
545                                         GetParam().second);
546
547   connect();
548
549   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
550       GetParam().second == SSLContext::NextProtocolType::ALPN) {
551     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
552     // mismatch should result in a fatal alert, but this is OpenSSL's current
553     // behavior and we want to know if it changes.
554     expectNoProtocol();
555   }
556 #if defined(OPENSSL_IS_BORINGSSL)
557   // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
558   // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
559   // NPN *and* ALPN
560   else if (
561       GetParam().first == SSLContext::NextProtocolType::ANY &&
562       GetParam().second == SSLContext::NextProtocolType::ANY) {
563     expectNoProtocol();
564   }
565 #endif
566   else {
567     expectProtocol("blub");
568     expectProtocolType(
569         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
570   }
571 }
572
573 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
574   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
575   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
576   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
577                                         GetParam().second);
578
579   connect();
580
581   expectProtocol("ponies");
582   expectProtocolType();
583 }
584
585 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
586   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
587   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
588   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
589                                         GetParam().second);
590
591   connect();
592
593   expectProtocol("blub");
594   expectProtocolType();
595 }
596
597 TEST_P(NextProtocolTest, RandomizedNpnTest) {
598   // Probability that this test will fail is 2^-64, which could be considered
599   // as negligible.
600   const int kTries = 64;
601
602   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
603                                         GetParam().first);
604   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
605                                                   GetParam().second);
606
607   std::set<string> selectedProtocols;
608   for (int i = 0; i < kTries; ++i) {
609     connect();
610
611     EXPECT_NE(client->nextProtoLength, 0);
612     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
613     EXPECT_EQ(
614         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
615         0);
616     string selected((const char*)client->nextProto, client->nextProtoLength);
617     selectedProtocols.insert(selected);
618     expectProtocolType();
619   }
620   EXPECT_EQ(selectedProtocols.size(), 2);
621 }
622
623 INSTANTIATE_TEST_CASE_P(
624     AsyncSSLSocketTest,
625     NextProtocolTest,
626     ::testing::Values(
627         NextProtocolTypePair(
628             SSLContext::NextProtocolType::NPN,
629             SSLContext::NextProtocolType::NPN),
630         NextProtocolTypePair(
631             SSLContext::NextProtocolType::NPN,
632             SSLContext::NextProtocolType::ANY),
633         NextProtocolTypePair(
634             SSLContext::NextProtocolType::ANY,
635             SSLContext::NextProtocolType::ANY)));
636
637 #if FOLLY_OPENSSL_HAS_ALPN
638 INSTANTIATE_TEST_CASE_P(
639     AsyncSSLSocketTest,
640     NextProtocolTLSExtTest,
641     ::testing::Values(
642         NextProtocolTypePair(
643             SSLContext::NextProtocolType::ALPN,
644             SSLContext::NextProtocolType::ALPN),
645         NextProtocolTypePair(
646             SSLContext::NextProtocolType::ALPN,
647             SSLContext::NextProtocolType::ANY),
648         NextProtocolTypePair(
649             SSLContext::NextProtocolType::ANY,
650             SSLContext::NextProtocolType::ALPN)));
651 #endif
652
653 INSTANTIATE_TEST_CASE_P(
654     AsyncSSLSocketTest,
655     NextProtocolNPNOnlyTest,
656     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
657                                            SSLContext::NextProtocolType::NPN)));
658
659 #if FOLLY_OPENSSL_HAS_ALPN
660 INSTANTIATE_TEST_CASE_P(
661     AsyncSSLSocketTest,
662     NextProtocolMismatchTest,
663     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
664                                            SSLContext::NextProtocolType::ALPN),
665                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
666                                            SSLContext::NextProtocolType::NPN)));
667 #endif
668
669 #ifndef OPENSSL_NO_TLSEXT
670 /**
671  * 1. Client sends TLSEXT_HOSTNAME in client hello.
672  * 2. Server found a match SSL_CTX and use this SSL_CTX to
673  *    continue the SSL handshake.
674  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
675  */
676 TEST(AsyncSSLSocketTest, SNITestMatch) {
677   EventBase eventBase;
678   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
679   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
680   // Use the same SSLContext to continue the handshake after
681   // tlsext_hostname match.
682   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
683   const std::string serverName("xyz.newdev.facebook.com");
684   int fds[2];
685   getfds(fds);
686   getctx(clientCtx, dfServerCtx);
687
688   AsyncSSLSocket::UniquePtr clientSock(
689     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
690   AsyncSSLSocket::UniquePtr serverSock(
691     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
692   SNIClient client(std::move(clientSock));
693   SNIServer server(std::move(serverSock),
694                    dfServerCtx,
695                    hskServerCtx,
696                    serverName);
697
698   eventBase.loop();
699
700   EXPECT_TRUE(client.serverNameMatch);
701   EXPECT_TRUE(server.serverNameMatch);
702 }
703
704 /**
705  * 1. Client sends TLSEXT_HOSTNAME in client hello.
706  * 2. Server cannot find a matching SSL_CTX and continue to use
707  *    the current SSL_CTX to do the handshake.
708  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
709  */
710 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
711   EventBase eventBase;
712   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
713   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
714   // Use the same SSLContext to continue the handshake after
715   // tlsext_hostname match.
716   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
717   const std::string clientRequestingServerName("foo.com");
718   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
719
720   int fds[2];
721   getfds(fds);
722   getctx(clientCtx, dfServerCtx);
723
724   AsyncSSLSocket::UniquePtr clientSock(
725     new AsyncSSLSocket(clientCtx,
726                         &eventBase,
727                         fds[0],
728                         clientRequestingServerName));
729   AsyncSSLSocket::UniquePtr serverSock(
730     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
731   SNIClient client(std::move(clientSock));
732   SNIServer server(std::move(serverSock),
733                    dfServerCtx,
734                    hskServerCtx,
735                    serverExpectedServerName);
736
737   eventBase.loop();
738
739   EXPECT_TRUE(!client.serverNameMatch);
740   EXPECT_TRUE(!server.serverNameMatch);
741 }
742 /**
743  * 1. Client sends TLSEXT_HOSTNAME in client hello.
744  * 2. We then change the serverName.
745  * 3. We expect that we get 'false' as the result for serNameMatch.
746  */
747
748 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
749    EventBase eventBase;
750   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
751   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
752   // Use the same SSLContext to continue the handshake after
753   // tlsext_hostname match.
754   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
755   const std::string serverName("xyz.newdev.facebook.com");
756   int fds[2];
757   getfds(fds);
758   getctx(clientCtx, dfServerCtx);
759
760   AsyncSSLSocket::UniquePtr clientSock(
761     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
762   //Change the server name
763   std::string newName("new.com");
764   clientSock->setServerName(newName);
765   AsyncSSLSocket::UniquePtr serverSock(
766     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
767   SNIClient client(std::move(clientSock));
768   SNIServer server(std::move(serverSock),
769                    dfServerCtx,
770                    hskServerCtx,
771                    serverName);
772
773   eventBase.loop();
774
775   EXPECT_TRUE(!client.serverNameMatch);
776 }
777
778 /**
779  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
780  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
781  */
782 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
783   EventBase eventBase;
784   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
785   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
786   // Use the same SSLContext to continue the handshake after
787   // tlsext_hostname match.
788   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
789   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
790
791   int fds[2];
792   getfds(fds);
793   getctx(clientCtx, dfServerCtx);
794
795   AsyncSSLSocket::UniquePtr clientSock(
796     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
797   AsyncSSLSocket::UniquePtr serverSock(
798     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
799   SNIClient client(std::move(clientSock));
800   SNIServer server(std::move(serverSock),
801                    dfServerCtx,
802                    hskServerCtx,
803                    serverExpectedServerName);
804
805   eventBase.loop();
806
807   EXPECT_TRUE(!client.serverNameMatch);
808   EXPECT_TRUE(!server.serverNameMatch);
809 }
810
811 #endif
812 /**
813  * Test SSL client socket
814  */
815 TEST(AsyncSSLSocketTest, SSLClientTest) {
816   // Start listening on a local port
817   WriteCallbackBase writeCallback;
818   ReadCallback readCallback(&writeCallback);
819   HandshakeCallback handshakeCallback(&readCallback);
820   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
821   TestSSLServer server(&acceptCallback);
822
823   // Set up SSL client
824   EventBase eventBase;
825   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
826
827   client->connect();
828   EventBaseAborter eba(&eventBase, 3000);
829   eventBase.loop();
830
831   EXPECT_EQ(client->getMiss(), 1);
832   EXPECT_EQ(client->getHit(), 0);
833
834   cerr << "SSLClientTest test completed" << endl;
835 }
836
837
838 /**
839  * Test SSL client socket session re-use
840  */
841 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
842   // Start listening on a local port
843   WriteCallbackBase writeCallback;
844   ReadCallback readCallback(&writeCallback);
845   HandshakeCallback handshakeCallback(&readCallback);
846   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
847   TestSSLServer server(&acceptCallback);
848
849   // Set up SSL client
850   EventBase eventBase;
851   auto client =
852       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
853
854   client->connect();
855   EventBaseAborter eba(&eventBase, 3000);
856   eventBase.loop();
857
858   EXPECT_EQ(client->getMiss(), 1);
859   EXPECT_EQ(client->getHit(), 9);
860
861   cerr << "SSLClientTestReuse test completed" << endl;
862 }
863
864 /**
865  * Test SSL client socket timeout
866  */
867 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
868   // Start listening on a local port
869   EmptyReadCallback readCallback;
870   HandshakeCallback handshakeCallback(&readCallback,
871                                       HandshakeCallback::EXPECT_ERROR);
872   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
873   TestSSLServer server(&acceptCallback);
874
875   // Set up SSL client
876   EventBase eventBase;
877   auto client =
878       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
879   client->connect(true /* write before connect completes */);
880   EventBaseAborter eba(&eventBase, 3000);
881   eventBase.loop();
882
883   usleep(100000);
884   // This is checking that the connectError callback precedes any queued
885   // writeError callbacks.  This matches AsyncSocket's behavior
886   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
887   EXPECT_EQ(client->getErrors(), 1);
888   EXPECT_EQ(client->getMiss(), 0);
889   EXPECT_EQ(client->getHit(), 0);
890
891   cerr << "SSLClientTimeoutTest test completed" << endl;
892 }
893
894 // The next 3 tests need an FB-only extension, and will fail without it
895 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
896 /**
897  * Test SSL server async cache
898  */
899 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
900   // Start listening on a local port
901   WriteCallbackBase writeCallback;
902   ReadCallback readCallback(&writeCallback);
903   HandshakeCallback handshakeCallback(&readCallback);
904   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
905   TestSSLAsyncCacheServer server(&acceptCallback);
906
907   // Set up SSL client
908   EventBase eventBase;
909   auto client =
910       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
911
912   client->connect();
913   EventBaseAborter eba(&eventBase, 3000);
914   eventBase.loop();
915
916   EXPECT_EQ(server.getAsyncCallbacks(), 18);
917   EXPECT_EQ(server.getAsyncLookups(), 9);
918   EXPECT_EQ(client->getMiss(), 10);
919   EXPECT_EQ(client->getHit(), 0);
920
921   cerr << "SSLServerAsyncCacheTest test completed" << endl;
922 }
923
924 /**
925  * Test SSL server accept timeout with cache path
926  */
927 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
928   // Start listening on a local port
929   WriteCallbackBase writeCallback;
930   ReadCallback readCallback(&writeCallback);
931   HandshakeCallback handshakeCallback(&readCallback);
932   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
933   TestSSLAsyncCacheServer server(&acceptCallback);
934
935   // Set up SSL client
936   EventBase eventBase;
937   // only do a TCP connect
938   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
939   sock->connect(nullptr, server.getAddress());
940
941   EmptyReadCallback clientReadCallback;
942   clientReadCallback.tcpSocket_ = sock;
943   sock->setReadCB(&clientReadCallback);
944
945   EventBaseAborter eba(&eventBase, 3000);
946   eventBase.loop();
947
948   EXPECT_EQ(readCallback.state, STATE_WAITING);
949
950   cerr << "SSLServerTimeoutTest test completed" << endl;
951 }
952
953 /**
954  * Test SSL server accept timeout with cache path
955  */
956 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
957   // Start listening on a local port
958   WriteCallbackBase writeCallback;
959   ReadCallback readCallback(&writeCallback);
960   HandshakeCallback handshakeCallback(&readCallback);
961   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
962   TestSSLAsyncCacheServer server(&acceptCallback);
963
964   // Set up SSL client
965   EventBase eventBase;
966   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
967
968   client->connect();
969   EventBaseAborter eba(&eventBase, 3000);
970   eventBase.loop();
971
972   EXPECT_EQ(server.getAsyncCallbacks(), 1);
973   EXPECT_EQ(server.getAsyncLookups(), 1);
974   EXPECT_EQ(client->getErrors(), 1);
975   EXPECT_EQ(client->getMiss(), 1);
976   EXPECT_EQ(client->getHit(), 0);
977
978   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
979 }
980
981 /**
982  * Test SSL server accept timeout with cache path
983  */
984 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
985   // Start listening on a local port
986   WriteCallbackBase writeCallback;
987   ReadCallback readCallback(&writeCallback);
988   HandshakeCallback handshakeCallback(&readCallback,
989                                       HandshakeCallback::EXPECT_ERROR);
990   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
991   TestSSLAsyncCacheServer server(&acceptCallback, 500);
992
993   // Set up SSL client
994   EventBase eventBase;
995   auto client =
996       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
997
998   client->connect();
999   EventBaseAborter eba(&eventBase, 3000);
1000   eventBase.loop();
1001
1002   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
1003       handshakeCallback.closeSocket();});
1004   // give time for the cache lookup to come back and find it closed
1005   handshakeCallback.waitForHandshake();
1006
1007   EXPECT_EQ(server.getAsyncCallbacks(), 1);
1008   EXPECT_EQ(server.getAsyncLookups(), 1);
1009   EXPECT_EQ(client->getErrors(), 1);
1010   EXPECT_EQ(client->getMiss(), 1);
1011   EXPECT_EQ(client->getHit(), 0);
1012
1013   cerr << "SSLServerCacheCloseTest test completed" << endl;
1014 }
1015 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1016
1017 /**
1018  * Verify Client Ciphers obtained using SSL MSG Callback.
1019  */
1020 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1021   EventBase eventBase;
1022   auto clientCtx = std::make_shared<SSLContext>();
1023   auto serverCtx = std::make_shared<SSLContext>();
1024   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1025   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1026   serverCtx->loadPrivateKey(testKey);
1027   serverCtx->loadCertificate(testCert);
1028   serverCtx->loadTrustedCertificates(testCA);
1029   serverCtx->loadClientCAList(testCA);
1030
1031   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1032   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1033   clientCtx->loadPrivateKey(testKey);
1034   clientCtx->loadCertificate(testCert);
1035   clientCtx->loadTrustedCertificates(testCA);
1036
1037   int fds[2];
1038   getfds(fds);
1039
1040   AsyncSSLSocket::UniquePtr clientSock(
1041       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1042   AsyncSSLSocket::UniquePtr serverSock(
1043       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1044
1045   SSLHandshakeClient client(std::move(clientSock), true, true);
1046   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1047
1048   eventBase.loop();
1049
1050 #if defined(OPENSSL_IS_BORINGSSL)
1051   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1052 #else
1053   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1054 #endif
1055   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1056   EXPECT_TRUE(client.handshakeVerify_);
1057   EXPECT_TRUE(client.handshakeSuccess_);
1058   EXPECT_TRUE(!client.handshakeError_);
1059   EXPECT_TRUE(server.handshakeVerify_);
1060   EXPECT_TRUE(server.handshakeSuccess_);
1061   EXPECT_TRUE(!server.handshakeError_);
1062 }
1063
1064 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1065   EventBase eventBase;
1066   auto ctx = std::make_shared<SSLContext>();
1067
1068   int fds[2];
1069   getfds(fds);
1070
1071   int bufLen = 42;
1072   uint8_t majorVersion = 18;
1073   uint8_t minorVersion = 25;
1074
1075   // Create callback buf
1076   auto buf = IOBuf::create(bufLen);
1077   buf->append(bufLen);
1078   folly::io::RWPrivateCursor cursor(buf.get());
1079   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1080   cursor.write<uint16_t>(0);
1081   cursor.write<uint8_t>(38);
1082   cursor.write<uint8_t>(majorVersion);
1083   cursor.write<uint8_t>(minorVersion);
1084   cursor.skip(32);
1085   cursor.write<uint32_t>(0);
1086
1087   SSL* ssl = ctx->createSSL();
1088   SCOPE_EXIT { SSL_free(ssl); };
1089   AsyncSSLSocket::UniquePtr sock(
1090       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1091   sock->enableClientHelloParsing();
1092
1093   // Test client hello parsing in one packet
1094   AsyncSSLSocket::clientHelloParsingCallback(
1095       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1096   buf.reset();
1097
1098   auto parsedClientHello = sock->getClientHelloInfo();
1099   EXPECT_TRUE(parsedClientHello != nullptr);
1100   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1101   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1102 }
1103
1104 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1105   EventBase eventBase;
1106   auto ctx = std::make_shared<SSLContext>();
1107
1108   int fds[2];
1109   getfds(fds);
1110
1111   int bufLen = 42;
1112   uint8_t majorVersion = 18;
1113   uint8_t minorVersion = 25;
1114
1115   // Create callback buf
1116   auto buf = IOBuf::create(bufLen);
1117   buf->append(bufLen);
1118   folly::io::RWPrivateCursor cursor(buf.get());
1119   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1120   cursor.write<uint16_t>(0);
1121   cursor.write<uint8_t>(38);
1122   cursor.write<uint8_t>(majorVersion);
1123   cursor.write<uint8_t>(minorVersion);
1124   cursor.skip(32);
1125   cursor.write<uint32_t>(0);
1126
1127   SSL* ssl = ctx->createSSL();
1128   SCOPE_EXIT { SSL_free(ssl); };
1129   AsyncSSLSocket::UniquePtr sock(
1130       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1131   sock->enableClientHelloParsing();
1132
1133   // Test parsing with two packets with first packet size < 3
1134   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1135   AsyncSSLSocket::clientHelloParsingCallback(
1136       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1137       ssl, sock.get());
1138   bufCopy.reset();
1139   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1140   AsyncSSLSocket::clientHelloParsingCallback(
1141       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1142       ssl, sock.get());
1143   bufCopy.reset();
1144
1145   auto parsedClientHello = sock->getClientHelloInfo();
1146   EXPECT_TRUE(parsedClientHello != nullptr);
1147   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1148   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1149 }
1150
1151 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1152   EventBase eventBase;
1153   auto ctx = std::make_shared<SSLContext>();
1154
1155   int fds[2];
1156   getfds(fds);
1157
1158   int bufLen = 42;
1159   uint8_t majorVersion = 18;
1160   uint8_t minorVersion = 25;
1161
1162   // Create callback buf
1163   auto buf = IOBuf::create(bufLen);
1164   buf->append(bufLen);
1165   folly::io::RWPrivateCursor cursor(buf.get());
1166   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1167   cursor.write<uint16_t>(0);
1168   cursor.write<uint8_t>(38);
1169   cursor.write<uint8_t>(majorVersion);
1170   cursor.write<uint8_t>(minorVersion);
1171   cursor.skip(32);
1172   cursor.write<uint32_t>(0);
1173
1174   SSL* ssl = ctx->createSSL();
1175   SCOPE_EXIT { SSL_free(ssl); };
1176   AsyncSSLSocket::UniquePtr sock(
1177       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1178   sock->enableClientHelloParsing();
1179
1180   // Test parsing with multiple small packets
1181   for (uint64_t i = 0; i < buf->length(); i += 3) {
1182     auto bufCopy = folly::IOBuf::copyBuffer(
1183         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1184     AsyncSSLSocket::clientHelloParsingCallback(
1185         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1186         ssl, sock.get());
1187     bufCopy.reset();
1188   }
1189
1190   auto parsedClientHello = sock->getClientHelloInfo();
1191   EXPECT_TRUE(parsedClientHello != nullptr);
1192   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1193   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1194 }
1195
1196 /**
1197  * Verify sucessful behavior of SSL certificate validation.
1198  */
1199 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1200   EventBase eventBase;
1201   auto clientCtx = std::make_shared<SSLContext>();
1202   auto dfServerCtx = std::make_shared<SSLContext>();
1203
1204   int fds[2];
1205   getfds(fds);
1206   getctx(clientCtx, dfServerCtx);
1207
1208   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1209   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1210
1211   AsyncSSLSocket::UniquePtr clientSock(
1212     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1213   AsyncSSLSocket::UniquePtr serverSock(
1214     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1215
1216   SSLHandshakeClient client(std::move(clientSock), true, true);
1217   clientCtx->loadTrustedCertificates(testCA);
1218
1219   SSLHandshakeServer server(std::move(serverSock), true, true);
1220
1221   eventBase.loop();
1222
1223   EXPECT_TRUE(client.handshakeVerify_);
1224   EXPECT_TRUE(client.handshakeSuccess_);
1225   EXPECT_TRUE(!client.handshakeError_);
1226   EXPECT_LE(0, client.handshakeTime.count());
1227   EXPECT_TRUE(!server.handshakeVerify_);
1228   EXPECT_TRUE(server.handshakeSuccess_);
1229   EXPECT_TRUE(!server.handshakeError_);
1230   EXPECT_LE(0, server.handshakeTime.count());
1231 }
1232
1233 /**
1234  * Verify that the client's verification callback is able to fail SSL
1235  * connection establishment.
1236  */
1237 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1238   EventBase eventBase;
1239   auto clientCtx = std::make_shared<SSLContext>();
1240   auto dfServerCtx = std::make_shared<SSLContext>();
1241
1242   int fds[2];
1243   getfds(fds);
1244   getctx(clientCtx, dfServerCtx);
1245
1246   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1247   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1248
1249   AsyncSSLSocket::UniquePtr clientSock(
1250     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1251   AsyncSSLSocket::UniquePtr serverSock(
1252     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1253
1254   SSLHandshakeClient client(std::move(clientSock), true, false);
1255   clientCtx->loadTrustedCertificates(testCA);
1256
1257   SSLHandshakeServer server(std::move(serverSock), true, true);
1258
1259   eventBase.loop();
1260
1261   EXPECT_TRUE(client.handshakeVerify_);
1262   EXPECT_TRUE(!client.handshakeSuccess_);
1263   EXPECT_TRUE(client.handshakeError_);
1264   EXPECT_LE(0, client.handshakeTime.count());
1265   EXPECT_TRUE(!server.handshakeVerify_);
1266   EXPECT_TRUE(!server.handshakeSuccess_);
1267   EXPECT_TRUE(server.handshakeError_);
1268   EXPECT_LE(0, server.handshakeTime.count());
1269 }
1270
1271 /**
1272  * Verify that the options in SSLContext can be overridden in
1273  * sslConnect/Accept.i.e specifying that no validation should be performed
1274  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1275  * the validation callback.
1276  */
1277 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1278   EventBase eventBase;
1279   auto clientCtx = std::make_shared<SSLContext>();
1280   auto dfServerCtx = std::make_shared<SSLContext>();
1281
1282   int fds[2];
1283   getfds(fds);
1284   getctx(clientCtx, dfServerCtx);
1285
1286   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1287   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1288
1289   AsyncSSLSocket::UniquePtr clientSock(
1290     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1291   AsyncSSLSocket::UniquePtr serverSock(
1292     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1293
1294   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1295   clientCtx->loadTrustedCertificates(testCA);
1296
1297   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1298
1299   eventBase.loop();
1300
1301   EXPECT_TRUE(!client.handshakeVerify_);
1302   EXPECT_TRUE(client.handshakeSuccess_);
1303   EXPECT_TRUE(!client.handshakeError_);
1304   EXPECT_LE(0, client.handshakeTime.count());
1305   EXPECT_TRUE(!server.handshakeVerify_);
1306   EXPECT_TRUE(server.handshakeSuccess_);
1307   EXPECT_TRUE(!server.handshakeError_);
1308   EXPECT_LE(0, server.handshakeTime.count());
1309 }
1310
1311 /**
1312  * Verify that the options in SSLContext can be overridden in
1313  * sslConnect/Accept. Enable verification even if context says otherwise.
1314  * Test requireClientCert with client cert
1315  */
1316 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1317   EventBase eventBase;
1318   auto clientCtx = std::make_shared<SSLContext>();
1319   auto serverCtx = std::make_shared<SSLContext>();
1320   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1321   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1322   serverCtx->loadPrivateKey(testKey);
1323   serverCtx->loadCertificate(testCert);
1324   serverCtx->loadTrustedCertificates(testCA);
1325   serverCtx->loadClientCAList(testCA);
1326
1327   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1328   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1329   clientCtx->loadPrivateKey(testKey);
1330   clientCtx->loadCertificate(testCert);
1331   clientCtx->loadTrustedCertificates(testCA);
1332
1333   int fds[2];
1334   getfds(fds);
1335
1336   AsyncSSLSocket::UniquePtr clientSock(
1337       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1338   AsyncSSLSocket::UniquePtr serverSock(
1339       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1340
1341   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1342   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1343
1344   eventBase.loop();
1345
1346   EXPECT_TRUE(client.handshakeVerify_);
1347   EXPECT_TRUE(client.handshakeSuccess_);
1348   EXPECT_FALSE(client.handshakeError_);
1349   EXPECT_LE(0, client.handshakeTime.count());
1350   EXPECT_TRUE(server.handshakeVerify_);
1351   EXPECT_TRUE(server.handshakeSuccess_);
1352   EXPECT_FALSE(server.handshakeError_);
1353   EXPECT_LE(0, server.handshakeTime.count());
1354 }
1355
1356 /**
1357  * Verify that the client's verification callback is able to override
1358  * the preverification failure and allow a successful connection.
1359  */
1360 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1361   EventBase eventBase;
1362   auto clientCtx = std::make_shared<SSLContext>();
1363   auto dfServerCtx = std::make_shared<SSLContext>();
1364
1365   int fds[2];
1366   getfds(fds);
1367   getctx(clientCtx, dfServerCtx);
1368
1369   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1370   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1371
1372   AsyncSSLSocket::UniquePtr clientSock(
1373     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1374   AsyncSSLSocket::UniquePtr serverSock(
1375     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1376
1377   SSLHandshakeClient client(std::move(clientSock), false, true);
1378   SSLHandshakeServer server(std::move(serverSock), true, true);
1379
1380   eventBase.loop();
1381
1382   EXPECT_TRUE(client.handshakeVerify_);
1383   EXPECT_TRUE(client.handshakeSuccess_);
1384   EXPECT_TRUE(!client.handshakeError_);
1385   EXPECT_LE(0, client.handshakeTime.count());
1386   EXPECT_TRUE(!server.handshakeVerify_);
1387   EXPECT_TRUE(server.handshakeSuccess_);
1388   EXPECT_TRUE(!server.handshakeError_);
1389   EXPECT_LE(0, server.handshakeTime.count());
1390 }
1391
1392 /**
1393  * Verify that specifying that no validation should be performed allows an
1394  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1395  * callback.
1396  */
1397 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1398   EventBase eventBase;
1399   auto clientCtx = std::make_shared<SSLContext>();
1400   auto dfServerCtx = std::make_shared<SSLContext>();
1401
1402   int fds[2];
1403   getfds(fds);
1404   getctx(clientCtx, dfServerCtx);
1405
1406   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1407   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1408
1409   AsyncSSLSocket::UniquePtr clientSock(
1410     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1411   AsyncSSLSocket::UniquePtr serverSock(
1412     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1413
1414   SSLHandshakeClient client(std::move(clientSock), false, false);
1415   SSLHandshakeServer server(std::move(serverSock), false, false);
1416
1417   eventBase.loop();
1418
1419   EXPECT_TRUE(!client.handshakeVerify_);
1420   EXPECT_TRUE(client.handshakeSuccess_);
1421   EXPECT_TRUE(!client.handshakeError_);
1422   EXPECT_LE(0, client.handshakeTime.count());
1423   EXPECT_TRUE(!server.handshakeVerify_);
1424   EXPECT_TRUE(server.handshakeSuccess_);
1425   EXPECT_TRUE(!server.handshakeError_);
1426   EXPECT_LE(0, server.handshakeTime.count());
1427 }
1428
1429 /**
1430  * Test requireClientCert with client cert
1431  */
1432 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1433   EventBase eventBase;
1434   auto clientCtx = std::make_shared<SSLContext>();
1435   auto serverCtx = std::make_shared<SSLContext>();
1436   serverCtx->setVerificationOption(
1437       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1438   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1439   serverCtx->loadPrivateKey(testKey);
1440   serverCtx->loadCertificate(testCert);
1441   serverCtx->loadTrustedCertificates(testCA);
1442   serverCtx->loadClientCAList(testCA);
1443
1444   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1445   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1446   clientCtx->loadPrivateKey(testKey);
1447   clientCtx->loadCertificate(testCert);
1448   clientCtx->loadTrustedCertificates(testCA);
1449
1450   int fds[2];
1451   getfds(fds);
1452
1453   AsyncSSLSocket::UniquePtr clientSock(
1454       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1455   AsyncSSLSocket::UniquePtr serverSock(
1456       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1457
1458   SSLHandshakeClient client(std::move(clientSock), true, true);
1459   SSLHandshakeServer server(std::move(serverSock), true, true);
1460
1461   eventBase.loop();
1462
1463   EXPECT_TRUE(client.handshakeVerify_);
1464   EXPECT_TRUE(client.handshakeSuccess_);
1465   EXPECT_FALSE(client.handshakeError_);
1466   EXPECT_LE(0, client.handshakeTime.count());
1467   EXPECT_TRUE(server.handshakeVerify_);
1468   EXPECT_TRUE(server.handshakeSuccess_);
1469   EXPECT_FALSE(server.handshakeError_);
1470   EXPECT_LE(0, server.handshakeTime.count());
1471 }
1472
1473
1474 /**
1475  * Test requireClientCert with no client cert
1476  */
1477 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1478   EventBase eventBase;
1479   auto clientCtx = std::make_shared<SSLContext>();
1480   auto serverCtx = std::make_shared<SSLContext>();
1481   serverCtx->setVerificationOption(
1482       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1483   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1484   serverCtx->loadPrivateKey(testKey);
1485   serverCtx->loadCertificate(testCert);
1486   serverCtx->loadTrustedCertificates(testCA);
1487   serverCtx->loadClientCAList(testCA);
1488   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1489   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1490
1491   int fds[2];
1492   getfds(fds);
1493
1494   AsyncSSLSocket::UniquePtr clientSock(
1495       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1496   AsyncSSLSocket::UniquePtr serverSock(
1497       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1498
1499   SSLHandshakeClient client(std::move(clientSock), false, false);
1500   SSLHandshakeServer server(std::move(serverSock), false, false);
1501
1502   eventBase.loop();
1503
1504   EXPECT_FALSE(server.handshakeVerify_);
1505   EXPECT_FALSE(server.handshakeSuccess_);
1506   EXPECT_TRUE(server.handshakeError_);
1507   EXPECT_LE(0, client.handshakeTime.count());
1508   EXPECT_LE(0, server.handshakeTime.count());
1509 }
1510
1511 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1512   auto cert = getFileAsBuf(testCert);
1513   auto key = getFileAsBuf(testKey);
1514
1515   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1516   BIO_write(certBio.get(), cert.data(), cert.size());
1517   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1518   BIO_write(keyBio.get(), key.data(), key.size());
1519
1520   // Create SSL structs from buffers to get properties
1521   ssl::X509UniquePtr certStruct(
1522       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1523   ssl::EvpPkeyUniquePtr keyStruct(
1524       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1525   certBio = nullptr;
1526   keyBio = nullptr;
1527
1528   auto origCommonName = getCommonName(certStruct.get());
1529   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1530   certStruct = nullptr;
1531   keyStruct = nullptr;
1532
1533   auto ctx = std::make_shared<SSLContext>();
1534   ctx->loadPrivateKeyFromBufferPEM(key);
1535   ctx->loadCertificateFromBufferPEM(cert);
1536   ctx->loadTrustedCertificates(testCA);
1537
1538   ssl::SSLUniquePtr ssl(ctx->createSSL());
1539
1540   auto newCert = SSL_get_certificate(ssl.get());
1541   auto newKey = SSL_get_privatekey(ssl.get());
1542
1543   // Get properties from SSL struct
1544   auto newCommonName = getCommonName(newCert);
1545   auto newKeySize = EVP_PKEY_bits(newKey);
1546
1547   // Check that the key and cert have the expected properties
1548   EXPECT_EQ(origCommonName, newCommonName);
1549   EXPECT_EQ(origKeySize, newKeySize);
1550 }
1551
1552 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1553   EventBase eb;
1554
1555   // Set up SSL context.
1556   auto sslContext = std::make_shared<SSLContext>();
1557   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1558
1559   // create SSL socket
1560   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1561
1562   EXPECT_EQ(1500, socket->getMinWriteSize());
1563
1564   socket->setMinWriteSize(0);
1565   EXPECT_EQ(0, socket->getMinWriteSize());
1566   socket->setMinWriteSize(50000);
1567   EXPECT_EQ(50000, socket->getMinWriteSize());
1568 }
1569
1570 class ReadCallbackTerminator : public ReadCallback {
1571  public:
1572   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1573       : ReadCallback(wcb)
1574       , base_(base) {}
1575
1576   // Do not write data back, terminate the loop.
1577   void readDataAvailable(size_t len) noexcept override {
1578     std::cerr << "readDataAvailable, len " << len << std::endl;
1579
1580     currentBuffer.length = len;
1581
1582     buffers.push_back(currentBuffer);
1583     currentBuffer.reset();
1584     state = STATE_SUCCEEDED;
1585
1586     socket_->setReadCB(nullptr);
1587     base_->terminateLoopSoon();
1588   }
1589  private:
1590   EventBase* base_;
1591 };
1592
1593
1594 /**
1595  * Test a full unencrypted codepath
1596  */
1597 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1598   EventBase base;
1599
1600   auto clientCtx = std::make_shared<folly::SSLContext>();
1601   auto serverCtx = std::make_shared<folly::SSLContext>();
1602   int fds[2];
1603   getfds(fds);
1604   getctx(clientCtx, serverCtx);
1605   auto client = AsyncSSLSocket::newSocket(
1606                   clientCtx, &base, fds[0], false, true);
1607   auto server = AsyncSSLSocket::newSocket(
1608                   serverCtx, &base, fds[1], true, true);
1609
1610   ReadCallbackTerminator readCallback(&base, nullptr);
1611   server->setReadCB(&readCallback);
1612   readCallback.setSocket(server);
1613
1614   uint8_t buf[128];
1615   memset(buf, 'a', sizeof(buf));
1616   client->write(nullptr, buf, sizeof(buf));
1617
1618   // Check that bytes are unencrypted
1619   char c;
1620   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1621   EXPECT_EQ('a', c);
1622
1623   EventBaseAborter eba(&base, 3000);
1624   base.loop();
1625
1626   EXPECT_EQ(1, readCallback.buffers.size());
1627   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1628
1629   server->setReadCB(&readCallback);
1630
1631   // Unencrypted
1632   server->sslAccept(nullptr);
1633   client->sslConn(nullptr);
1634
1635   // Do NOT wait for handshake, writing should be queued and happen after
1636
1637   client->write(nullptr, buf, sizeof(buf));
1638
1639   // Check that bytes are *not* unencrypted
1640   char c2;
1641   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1642   EXPECT_NE('a', c2);
1643
1644
1645   base.loop();
1646
1647   EXPECT_EQ(2, readCallback.buffers.size());
1648   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1649 }
1650
1651 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1652   // Start listening on a local port
1653   WriteCallbackBase writeCallback;
1654   WriteErrorCallback readCallback(&writeCallback);
1655   HandshakeCallback handshakeCallback(&readCallback,
1656                                       HandshakeCallback::EXPECT_ERROR);
1657   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1658   TestSSLServer server(&acceptCallback);
1659
1660   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1661   socket->open();
1662   uint8_t buf[3] = {0x16, 0x03, 0x01};
1663   socket->write(buf, sizeof(buf));
1664   socket->closeWithReset();
1665
1666   handshakeCallback.waitForHandshake();
1667   EXPECT_NE(
1668       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1669   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1670 }
1671
1672 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1673   // Start listening on a local port
1674   WriteCallbackBase writeCallback;
1675   WriteErrorCallback readCallback(&writeCallback);
1676   HandshakeCallback handshakeCallback(&readCallback,
1677                                       HandshakeCallback::EXPECT_ERROR);
1678   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1679   TestSSLServer server(&acceptCallback);
1680
1681   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1682   socket->open();
1683   uint8_t buf[3] = {0x16, 0x03, 0x01};
1684   socket->write(buf, sizeof(buf));
1685   socket->close();
1686
1687   handshakeCallback.waitForHandshake();
1688   EXPECT_NE(
1689       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1690   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1691 }
1692
1693 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1694   // Start listening on a local port
1695   WriteCallbackBase writeCallback;
1696   WriteErrorCallback readCallback(&writeCallback);
1697   HandshakeCallback handshakeCallback(&readCallback,
1698                                       HandshakeCallback::EXPECT_ERROR);
1699   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1700   TestSSLServer server(&acceptCallback);
1701
1702   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1703   socket->open();
1704   uint8_t buf[256] = {0x16, 0x03};
1705   memset(buf + 2, 'a', sizeof(buf) - 2);
1706   socket->write(buf, sizeof(buf));
1707   socket->close();
1708
1709   handshakeCallback.waitForHandshake();
1710   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1711             std::string::npos);
1712 #if defined(OPENSSL_IS_BORINGSSL)
1713   EXPECT_NE(
1714       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1715       std::string::npos);
1716 #else
1717   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1718             std::string::npos);
1719 #endif
1720 }
1721
1722 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1723   using folly::ssl::OpenSSLUtils;
1724   EXPECT_EQ(
1725       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1726   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1727   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1728   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1729   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1730 }
1731
1732 #if FOLLY_ALLOW_TFO
1733
1734 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1735  public:
1736   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1737
1738   explicit MockAsyncTFOSSLSocket(
1739       std::shared_ptr<folly::SSLContext> sslCtx,
1740       EventBase* evb)
1741       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1742
1743   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1744 };
1745
1746 /**
1747  * Test connecting to, writing to, reading from, and closing the
1748  * connection to the SSL server with TFO.
1749  */
1750 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1751   // Start listening on a local port
1752   WriteCallbackBase writeCallback;
1753   ReadCallback readCallback(&writeCallback);
1754   HandshakeCallback handshakeCallback(&readCallback);
1755   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1756   TestSSLServer server(&acceptCallback, true);
1757
1758   // Set up SSL context.
1759   auto sslContext = std::make_shared<SSLContext>();
1760
1761   // connect
1762   auto socket =
1763       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1764   socket->enableTFO();
1765   socket->open();
1766
1767   // write()
1768   std::array<uint8_t, 128> buf;
1769   memset(buf.data(), 'a', buf.size());
1770   socket->write(buf.data(), buf.size());
1771
1772   // read()
1773   std::array<uint8_t, 128> readbuf;
1774   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1775   EXPECT_EQ(bytesRead, 128);
1776   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1777
1778   // close()
1779   socket->close();
1780 }
1781
1782 /**
1783  * Test connecting to, writing to, reading from, and closing the
1784  * connection to the SSL server with TFO.
1785  */
1786 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1787   // Start listening on a local port
1788   WriteCallbackBase writeCallback;
1789   ReadCallback readCallback(&writeCallback);
1790   HandshakeCallback handshakeCallback(&readCallback);
1791   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1792   TestSSLServer server(&acceptCallback, false);
1793
1794   // Set up SSL context.
1795   auto sslContext = std::make_shared<SSLContext>();
1796
1797   // connect
1798   auto socket =
1799       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1800   socket->enableTFO();
1801   socket->open();
1802
1803   // write()
1804   std::array<uint8_t, 128> buf;
1805   memset(buf.data(), 'a', buf.size());
1806   socket->write(buf.data(), buf.size());
1807
1808   // read()
1809   std::array<uint8_t, 128> readbuf;
1810   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1811   EXPECT_EQ(bytesRead, 128);
1812   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1813
1814   // close()
1815   socket->close();
1816 }
1817
1818 class ConnCallback : public AsyncSocket::ConnectCallback {
1819  public:
1820   virtual void connectSuccess() noexcept override {
1821     state = State::SUCCESS;
1822   }
1823
1824   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1825     state = State::ERROR;
1826     error = ex.what();
1827   }
1828
1829   enum class State { WAITING, SUCCESS, ERROR };
1830
1831   State state{State::WAITING};
1832   std::string error;
1833 };
1834
1835 template <class Cardinality>
1836 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1837     EventBase* evb,
1838     const SocketAddress& address,
1839     Cardinality cardinality) {
1840   // Set up SSL context.
1841   auto sslContext = std::make_shared<SSLContext>();
1842
1843   // connect
1844   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1845       new MockAsyncTFOSSLSocket(sslContext, evb));
1846   socket->enableTFO();
1847
1848   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1849       .Times(cardinality)
1850       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1851         sockaddr_storage addr;
1852         auto len = address.getAddress(&addr);
1853         return connect(fd, (const struct sockaddr*)&addr, len);
1854       }));
1855   return socket;
1856 }
1857
1858 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1859   // Start listening on a local port
1860   WriteCallbackBase writeCallback;
1861   ReadCallback readCallback(&writeCallback);
1862   HandshakeCallback handshakeCallback(&readCallback);
1863   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1864   TestSSLServer server(&acceptCallback, true);
1865
1866   EventBase evb;
1867
1868   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1869   ConnCallback ccb;
1870   socket->connect(&ccb, server.getAddress(), 30);
1871
1872   evb.loop();
1873   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1874
1875   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1876   evb.loop();
1877
1878   BlockingSocket sock(std::move(socket));
1879   // write()
1880   std::array<uint8_t, 128> buf;
1881   memset(buf.data(), 'a', buf.size());
1882   sock.write(buf.data(), buf.size());
1883
1884   // read()
1885   std::array<uint8_t, 128> readbuf;
1886   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1887   EXPECT_EQ(bytesRead, 128);
1888   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1889
1890   // close()
1891   sock.close();
1892 }
1893
1894 #if !defined(OPENSSL_IS_BORINGSSL)
1895 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1896   // Start listening on a local port
1897   ConnectTimeoutCallback acceptCallback;
1898   TestSSLServer server(&acceptCallback, true);
1899
1900   // Set up SSL context.
1901   auto sslContext = std::make_shared<SSLContext>();
1902
1903   // connect
1904   auto socket =
1905       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1906   socket->enableTFO();
1907   EXPECT_THROW(
1908       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1909 }
1910 #endif
1911
1912 #if !defined(OPENSSL_IS_BORINGSSL)
1913 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1914   // Start listening on a local port
1915   ConnectTimeoutCallback acceptCallback;
1916   TestSSLServer server(&acceptCallback, true);
1917
1918   EventBase evb;
1919
1920   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1921   ConnCallback ccb;
1922   // Set a short timeout
1923   socket->connect(&ccb, server.getAddress(), 1);
1924
1925   evb.loop();
1926   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1927 }
1928 #endif
1929
1930 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1931   // Start listening on a local port
1932   EmptyReadCallback readCallback;
1933   HandshakeCallback handshakeCallback(
1934       &readCallback, HandshakeCallback::EXPECT_ERROR);
1935   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1936   TestSSLServer server(&acceptCallback, true);
1937
1938   EventBase evb;
1939
1940   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1941   ConnCallback ccb;
1942   socket->connect(&ccb, server.getAddress(), 100);
1943
1944   evb.loop();
1945   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1946   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1947 }
1948
1949 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1950   // Start listening on a local port
1951   EventBase evb;
1952
1953   // Hopefully nothing is listening on this address
1954   SocketAddress addr("127.0.0.1", 65535);
1955   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1956   ConnCallback ccb;
1957   socket->connect(&ccb, addr, 100);
1958
1959   evb.loop();
1960   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1961   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1962 }
1963
1964 #endif
1965
1966 } // namespace
1967
1968 #ifdef SIGPIPE
1969 ///////////////////////////////////////////////////////////////////////////
1970 // init_unit_test_suite
1971 ///////////////////////////////////////////////////////////////////////////
1972 namespace {
1973 struct Initializer {
1974   Initializer() {
1975     signal(SIGPIPE, SIG_IGN);
1976   }
1977 };
1978 Initializer initializer;
1979 } // anonymous
1980 #endif