181a33451e226b0f227f9cc2b9688988a592ed7d
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/GMock.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/Unistd.h>
28
29 #include <folly/io/async/test/BlockingSocket.h>
30
31 #include <fcntl.h>
32 #include <folly/io/Cursor.h>
33 #include <openssl/bio.h>
34 #include <sys/types.h>
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
56 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
57 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
58
59 constexpr size_t SSLClient::kMaxReadBufferSz;
60 constexpr size_t SSLClient::kMaxReadsPerEvent;
61
62 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
63     : ctx_(new folly::SSLContext),
64       acb_(acb),
65       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
66   // Set up the SSL context
67   ctx_->loadCertificate(testCert);
68   ctx_->loadPrivateKey(testKey);
69   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
70
71   acb_->ctx_ = ctx_;
72   acb_->base_ = &evb_;
73
74   // Enable TFO
75   if (enableTFO) {
76     LOG(INFO) << "server TFO enabled";
77     socket_->setTFOEnabled(true, 1000);
78   }
79
80   // set up the listening socket
81   socket_->bind(0);
82   socket_->getAddress(&address_);
83   socket_->listen(100);
84   socket_->addAcceptCallback(acb_, &evb_);
85   socket_->startAccepting();
86
87   int ret = pthread_create(&thread_, nullptr, Main, this);
88   assert(ret == 0);
89   (void)ret;
90
91   std::cerr << "Accepting connections on " << address_ << std::endl;
92 }
93
94 void getfds(int fds[2]) {
95   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
96     FAIL() << "failed to create socketpair: " << strerror(errno);
97   }
98   for (int idx = 0; idx < 2; ++idx) {
99     int flags = fcntl(fds[idx], F_GETFL, 0);
100     if (flags == -1) {
101       FAIL() << "failed to get flags for socket " << idx << ": "
102              << strerror(errno);
103     }
104     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
105       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
106              << strerror(errno);
107     }
108   }
109 }
110
111 void getctx(
112   std::shared_ptr<folly::SSLContext> clientCtx,
113   std::shared_ptr<folly::SSLContext> serverCtx) {
114   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
115
116   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
117   serverCtx->loadCertificate(
118       testCert);
119   serverCtx->loadPrivateKey(
120       testKey);
121 }
122
123 void sslsocketpair(
124   EventBase* eventBase,
125   AsyncSSLSocket::UniquePtr* clientSock,
126   AsyncSSLSocket::UniquePtr* serverSock) {
127   auto clientCtx = std::make_shared<folly::SSLContext>();
128   auto serverCtx = std::make_shared<folly::SSLContext>();
129   int fds[2];
130   getfds(fds);
131   getctx(clientCtx, serverCtx);
132   clientSock->reset(new AsyncSSLSocket(
133                       clientCtx, eventBase, fds[0], false));
134   serverSock->reset(new AsyncSSLSocket(
135                       serverCtx, eventBase, fds[1], true));
136
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
142 bool clientProtoFilterPickPony(unsigned char** client,
143   unsigned int* client_len, const unsigned char*, unsigned int ) {
144   //the protocol string in length prefixed byte string. the
145   //length byte is not included in the length
146   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
147   *client = p;
148   *client_len = 7;
149   return true;
150 }
151
152 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
153   const unsigned char*, unsigned int) {
154   return false;
155 }
156
157 std::string getFileAsBuf(const char* fileName) {
158   std::string buffer;
159   folly::readFile(fileName, buffer);
160   return buffer;
161 }
162
163 std::string getCommonName(X509* cert) {
164   X509_NAME* subject = X509_get_subject_name(cert);
165   std::string cn;
166   cn.resize(ub_common_name);
167   X509_NAME_get_text_by_NID(
168       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
169   return cn;
170 }
171
172 /**
173  * Test connecting to, writing to, reading from, and closing the
174  * connection to the SSL server.
175  */
176 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
177   // Start listening on a local port
178   WriteCallbackBase writeCallback;
179   ReadCallback readCallback(&writeCallback);
180   HandshakeCallback handshakeCallback(&readCallback);
181   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
182   TestSSLServer server(&acceptCallback);
183
184   // Set up SSL context.
185   std::shared_ptr<SSLContext> sslContext(new SSLContext());
186   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
187   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
188   //sslContext->authenticate(true, false);
189
190   // connect
191   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
192                                                  sslContext);
193   socket->open();
194
195   // write()
196   uint8_t buf[128];
197   memset(buf, 'a', sizeof(buf));
198   socket->write(buf, sizeof(buf));
199
200   // read()
201   uint8_t readbuf[128];
202   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
203   EXPECT_EQ(bytesRead, 128);
204   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
205
206   // close()
207   socket->close();
208
209   cerr << "ConnectWriteReadClose test completed" << endl;
210 }
211
212 /**
213  * Test reading after server close.
214  */
215 TEST(AsyncSSLSocketTest, ReadAfterClose) {
216   // Start listening on a local port
217   WriteCallbackBase writeCallback;
218   ReadEOFCallback readCallback(&writeCallback);
219   HandshakeCallback handshakeCallback(&readCallback);
220   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
221   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
222
223   // Set up SSL context.
224   auto sslContext = std::make_shared<SSLContext>();
225   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
226
227   auto socket =
228       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
229   socket->open();
230
231   // This should trigger an EOF on the client.
232   auto evb = handshakeCallback.getSocket()->getEventBase();
233   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
234   std::array<uint8_t, 128> readbuf;
235   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
236   EXPECT_EQ(0, bytesRead);
237 }
238
239 /**
240  * Test bad renegotiation
241  */
242 TEST(AsyncSSLSocketTest, Renegotiate) {
243   EventBase eventBase;
244   auto clientCtx = std::make_shared<SSLContext>();
245   auto dfServerCtx = std::make_shared<SSLContext>();
246   std::array<int, 2> fds;
247   getfds(fds.data());
248   getctx(clientCtx, dfServerCtx);
249
250   AsyncSSLSocket::UniquePtr clientSock(
251       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
252   AsyncSSLSocket::UniquePtr serverSock(
253       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
254   SSLHandshakeClient client(std::move(clientSock), true, true);
255   RenegotiatingServer server(std::move(serverSock));
256
257   while (!client.handshakeSuccess_ && !client.handshakeError_) {
258     eventBase.loopOnce();
259   }
260
261   ASSERT_TRUE(client.handshakeSuccess_);
262
263   auto sslSock = std::move(client).moveSocket();
264   sslSock->detachEventBase();
265   // This is nasty, however we don't want to add support for
266   // renegotiation in AsyncSSLSocket.
267   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
268
269   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
270
271   std::thread t([&]() { eventBase.loopForever(); });
272
273   // Trigger the renegotiation.
274   std::array<uint8_t, 128> buf;
275   memset(buf.data(), 'a', buf.size());
276   try {
277     socket->write(buf.data(), buf.size());
278   } catch (AsyncSocketException& e) {
279     LOG(INFO) << "client got error " << e.what();
280   }
281   eventBase.terminateLoopSoon();
282   t.join();
283
284   eventBase.loop();
285   ASSERT_TRUE(server.renegotiationError_);
286 }
287
288 /**
289  * Negative test for handshakeError().
290  */
291 TEST(AsyncSSLSocketTest, HandshakeError) {
292   // Start listening on a local port
293   WriteCallbackBase writeCallback;
294   WriteErrorCallback readCallback(&writeCallback);
295   HandshakeCallback handshakeCallback(&readCallback);
296   HandshakeErrorCallback acceptCallback(&handshakeCallback);
297   TestSSLServer server(&acceptCallback);
298
299   // Set up SSL context.
300   std::shared_ptr<SSLContext> sslContext(new SSLContext());
301   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
302
303   // connect
304   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
305                                                  sslContext);
306   // read()
307   bool ex = false;
308   try {
309     socket->open();
310
311     uint8_t readbuf[128];
312     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
313     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
314   } catch (AsyncSocketException&) {
315     ex = true;
316   }
317   EXPECT_TRUE(ex);
318
319   // close()
320   socket->close();
321   cerr << "HandshakeError test completed" << endl;
322 }
323
324 /**
325  * Negative test for readError().
326  */
327 TEST(AsyncSSLSocketTest, ReadError) {
328   // Start listening on a local port
329   WriteCallbackBase writeCallback;
330   ReadErrorCallback readCallback(&writeCallback);
331   HandshakeCallback handshakeCallback(&readCallback);
332   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
333   TestSSLServer server(&acceptCallback);
334
335   // Set up SSL context.
336   std::shared_ptr<SSLContext> sslContext(new SSLContext());
337   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
338
339   // connect
340   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
341                                                  sslContext);
342   socket->open();
343
344   // write something to trigger ssl handshake
345   uint8_t buf[128];
346   memset(buf, 'a', sizeof(buf));
347   socket->write(buf, sizeof(buf));
348
349   socket->close();
350   cerr << "ReadError test completed" << endl;
351 }
352
353 /**
354  * Negative test for writeError().
355  */
356 TEST(AsyncSSLSocketTest, WriteError) {
357   // Start listening on a local port
358   WriteCallbackBase writeCallback;
359   WriteErrorCallback readCallback(&writeCallback);
360   HandshakeCallback handshakeCallback(&readCallback);
361   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
362   TestSSLServer server(&acceptCallback);
363
364   // Set up SSL context.
365   std::shared_ptr<SSLContext> sslContext(new SSLContext());
366   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
367
368   // connect
369   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
370                                                  sslContext);
371   socket->open();
372
373   // write something to trigger ssl handshake
374   uint8_t buf[128];
375   memset(buf, 'a', sizeof(buf));
376   socket->write(buf, sizeof(buf));
377
378   socket->close();
379   cerr << "WriteError test completed" << endl;
380 }
381
382 /**
383  * Test a socket with TCP_NODELAY unset.
384  */
385 TEST(AsyncSSLSocketTest, SocketWithDelay) {
386   // Start listening on a local port
387   WriteCallbackBase writeCallback;
388   ReadCallback readCallback(&writeCallback);
389   HandshakeCallback handshakeCallback(&readCallback);
390   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
391   TestSSLServer server(&acceptCallback);
392
393   // Set up SSL context.
394   std::shared_ptr<SSLContext> sslContext(new SSLContext());
395   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
396
397   // connect
398   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
399                                                  sslContext);
400   socket->open();
401
402   // write()
403   uint8_t buf[128];
404   memset(buf, 'a', sizeof(buf));
405   socket->write(buf, sizeof(buf));
406
407   // read()
408   uint8_t readbuf[128];
409   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
410   EXPECT_EQ(bytesRead, 128);
411   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
412
413   // close()
414   socket->close();
415
416   cerr << "SocketWithDelay test completed" << endl;
417 }
418
419 using NextProtocolTypePair =
420     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
421
422 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
423   // For matching protos
424  public:
425   void SetUp() override { getctx(clientCtx, serverCtx); }
426
427   void connect(bool unset = false) {
428     getfds(fds);
429
430     if (unset) {
431       // unsetting NPN for any of [client, server] is enough to make NPN not
432       // work
433       clientCtx->unsetNextProtocols();
434     }
435
436     AsyncSSLSocket::UniquePtr clientSock(
437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
438     AsyncSSLSocket::UniquePtr serverSock(
439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
440     client = folly::make_unique<NpnClient>(std::move(clientSock));
441     server = folly::make_unique<NpnServer>(std::move(serverSock));
442
443     eventBase.loop();
444   }
445
446   void expectProtocol(const std::string& proto) {
447     EXPECT_NE(client->nextProtoLength, 0);
448     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
449     EXPECT_EQ(
450         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
451         0);
452     string selected((const char*)client->nextProto, client->nextProtoLength);
453     EXPECT_EQ(proto, selected);
454   }
455
456   void expectNoProtocol() {
457     EXPECT_EQ(client->nextProtoLength, 0);
458     EXPECT_EQ(server->nextProtoLength, 0);
459     EXPECT_EQ(client->nextProto, nullptr);
460     EXPECT_EQ(server->nextProto, nullptr);
461   }
462
463   void expectProtocolType() {
464     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
465         GetParam().second == SSLContext::NextProtocolType::ANY) {
466       EXPECT_EQ(client->protocolType, server->protocolType);
467     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
468                GetParam().second == SSLContext::NextProtocolType::ANY) {
469       // Well not much we can say
470     } else {
471       expectProtocolType(GetParam());
472     }
473   }
474
475   void expectProtocolType(NextProtocolTypePair expected) {
476     EXPECT_EQ(client->protocolType, expected.first);
477     EXPECT_EQ(server->protocolType, expected.second);
478   }
479
480   EventBase eventBase;
481   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
482   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
483   int fds[2];
484   std::unique_ptr<NpnClient> client;
485   std::unique_ptr<NpnServer> server;
486 };
487
488 class NextProtocolTLSExtTest : public NextProtocolTest {
489   // For extended TLS protos
490 };
491
492 class NextProtocolNPNOnlyTest : public NextProtocolTest {
493   // For mismatching protos
494 };
495
496 class NextProtocolMismatchTest : public NextProtocolTest {
497   // For mismatching protos
498 };
499
500 TEST_P(NextProtocolTest, NpnTestOverlap) {
501   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
502   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
503                                         GetParam().second);
504
505   connect();
506
507   expectProtocol("baz");
508   expectProtocolType();
509 }
510
511 TEST_P(NextProtocolTest, NpnTestUnset) {
512   // Identical to above test, except that we want unset NPN before
513   // looping.
514   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
515   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
516                                         GetParam().second);
517
518   connect(true /* unset */);
519
520   // if alpn negotiation fails, type will appear as npn
521   expectNoProtocol();
522   EXPECT_EQ(client->protocolType, server->protocolType);
523 }
524
525 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529
530   connect();
531
532   expectNoProtocol();
533   expectProtocolType(
534       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
535 }
536
537 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
538 // will fail on 1.0.2 before that.
539 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
540   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
541   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
542                                         GetParam().second);
543
544   connect();
545
546   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
547       GetParam().second == SSLContext::NextProtocolType::ALPN) {
548     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
549     // mismatch should result in a fatal alert, but this is OpenSSL's current
550     // behavior and we want to know if it changes.
551     expectNoProtocol();
552   } else {
553     expectProtocol("blub");
554     expectProtocolType(
555         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
556   }
557 }
558
559 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
560   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
561   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
562   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
563                                         GetParam().second);
564
565   connect();
566
567   expectProtocol("ponies");
568   expectProtocolType();
569 }
570
571 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
572   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
573   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
574   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
575                                         GetParam().second);
576
577   connect();
578
579   expectProtocol("blub");
580   expectProtocolType();
581 }
582
583 TEST_P(NextProtocolTest, RandomizedNpnTest) {
584   // Probability that this test will fail is 2^-64, which could be considered
585   // as negligible.
586   const int kTries = 64;
587
588   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
589                                         GetParam().first);
590   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
591                                                   GetParam().second);
592
593   std::set<string> selectedProtocols;
594   for (int i = 0; i < kTries; ++i) {
595     connect();
596
597     EXPECT_NE(client->nextProtoLength, 0);
598     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
599     EXPECT_EQ(
600         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
601         0);
602     string selected((const char*)client->nextProto, client->nextProtoLength);
603     selectedProtocols.insert(selected);
604     expectProtocolType();
605   }
606   EXPECT_EQ(selectedProtocols.size(), 2);
607 }
608
609 INSTANTIATE_TEST_CASE_P(
610     AsyncSSLSocketTest,
611     NextProtocolTest,
612     ::testing::Values(
613         NextProtocolTypePair(
614             SSLContext::NextProtocolType::NPN,
615             SSLContext::NextProtocolType::NPN),
616         NextProtocolTypePair(
617             SSLContext::NextProtocolType::NPN,
618             SSLContext::NextProtocolType::ANY),
619         NextProtocolTypePair(
620             SSLContext::NextProtocolType::ANY,
621             SSLContext::NextProtocolType::ANY)));
622
623 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
624 INSTANTIATE_TEST_CASE_P(
625     AsyncSSLSocketTest,
626     NextProtocolTLSExtTest,
627     ::testing::Values(
628         NextProtocolTypePair(
629             SSLContext::NextProtocolType::ALPN,
630             SSLContext::NextProtocolType::ALPN),
631         NextProtocolTypePair(
632             SSLContext::NextProtocolType::ALPN,
633             SSLContext::NextProtocolType::ANY),
634         NextProtocolTypePair(
635             SSLContext::NextProtocolType::ANY,
636             SSLContext::NextProtocolType::ALPN)));
637 #endif
638
639 INSTANTIATE_TEST_CASE_P(
640     AsyncSSLSocketTest,
641     NextProtocolNPNOnlyTest,
642     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
643                                            SSLContext::NextProtocolType::NPN)));
644
645 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
646 INSTANTIATE_TEST_CASE_P(
647     AsyncSSLSocketTest,
648     NextProtocolMismatchTest,
649     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
650                                            SSLContext::NextProtocolType::ALPN),
651                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
652                                            SSLContext::NextProtocolType::NPN)));
653 #endif
654
655 #ifndef OPENSSL_NO_TLSEXT
656 /**
657  * 1. Client sends TLSEXT_HOSTNAME in client hello.
658  * 2. Server found a match SSL_CTX and use this SSL_CTX to
659  *    continue the SSL handshake.
660  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
661  */
662 TEST(AsyncSSLSocketTest, SNITestMatch) {
663   EventBase eventBase;
664   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
665   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
666   // Use the same SSLContext to continue the handshake after
667   // tlsext_hostname match.
668   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
669   const std::string serverName("xyz.newdev.facebook.com");
670   int fds[2];
671   getfds(fds);
672   getctx(clientCtx, dfServerCtx);
673
674   AsyncSSLSocket::UniquePtr clientSock(
675     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
676   AsyncSSLSocket::UniquePtr serverSock(
677     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
678   SNIClient client(std::move(clientSock));
679   SNIServer server(std::move(serverSock),
680                    dfServerCtx,
681                    hskServerCtx,
682                    serverName);
683
684   eventBase.loop();
685
686   EXPECT_TRUE(client.serverNameMatch);
687   EXPECT_TRUE(server.serverNameMatch);
688 }
689
690 /**
691  * 1. Client sends TLSEXT_HOSTNAME in client hello.
692  * 2. Server cannot find a matching SSL_CTX and continue to use
693  *    the current SSL_CTX to do the handshake.
694  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
695  */
696 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
697   EventBase eventBase;
698   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
699   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
700   // Use the same SSLContext to continue the handshake after
701   // tlsext_hostname match.
702   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
703   const std::string clientRequestingServerName("foo.com");
704   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
705
706   int fds[2];
707   getfds(fds);
708   getctx(clientCtx, dfServerCtx);
709
710   AsyncSSLSocket::UniquePtr clientSock(
711     new AsyncSSLSocket(clientCtx,
712                         &eventBase,
713                         fds[0],
714                         clientRequestingServerName));
715   AsyncSSLSocket::UniquePtr serverSock(
716     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
717   SNIClient client(std::move(clientSock));
718   SNIServer server(std::move(serverSock),
719                    dfServerCtx,
720                    hskServerCtx,
721                    serverExpectedServerName);
722
723   eventBase.loop();
724
725   EXPECT_TRUE(!client.serverNameMatch);
726   EXPECT_TRUE(!server.serverNameMatch);
727 }
728 /**
729  * 1. Client sends TLSEXT_HOSTNAME in client hello.
730  * 2. We then change the serverName.
731  * 3. We expect that we get 'false' as the result for serNameMatch.
732  */
733
734 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
735    EventBase eventBase;
736   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
737   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
738   // Use the same SSLContext to continue the handshake after
739   // tlsext_hostname match.
740   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
741   const std::string serverName("xyz.newdev.facebook.com");
742   int fds[2];
743   getfds(fds);
744   getctx(clientCtx, dfServerCtx);
745
746   AsyncSSLSocket::UniquePtr clientSock(
747     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
748   //Change the server name
749   std::string newName("new.com");
750   clientSock->setServerName(newName);
751   AsyncSSLSocket::UniquePtr serverSock(
752     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
753   SNIClient client(std::move(clientSock));
754   SNIServer server(std::move(serverSock),
755                    dfServerCtx,
756                    hskServerCtx,
757                    serverName);
758
759   eventBase.loop();
760
761   EXPECT_TRUE(!client.serverNameMatch);
762 }
763
764 /**
765  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
766  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
767  */
768 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
769   EventBase eventBase;
770   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
771   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
772   // Use the same SSLContext to continue the handshake after
773   // tlsext_hostname match.
774   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
775   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
776
777   int fds[2];
778   getfds(fds);
779   getctx(clientCtx, dfServerCtx);
780
781   AsyncSSLSocket::UniquePtr clientSock(
782     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
783   AsyncSSLSocket::UniquePtr serverSock(
784     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
785   SNIClient client(std::move(clientSock));
786   SNIServer server(std::move(serverSock),
787                    dfServerCtx,
788                    hskServerCtx,
789                    serverExpectedServerName);
790
791   eventBase.loop();
792
793   EXPECT_TRUE(!client.serverNameMatch);
794   EXPECT_TRUE(!server.serverNameMatch);
795 }
796
797 #endif
798 /**
799  * Test SSL client socket
800  */
801 TEST(AsyncSSLSocketTest, SSLClientTest) {
802   // Start listening on a local port
803   WriteCallbackBase writeCallback;
804   ReadCallback readCallback(&writeCallback);
805   HandshakeCallback handshakeCallback(&readCallback);
806   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
807   TestSSLServer server(&acceptCallback);
808
809   // Set up SSL client
810   EventBase eventBase;
811   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
812
813   client->connect();
814   EventBaseAborter eba(&eventBase, 3000);
815   eventBase.loop();
816
817   EXPECT_EQ(client->getMiss(), 1);
818   EXPECT_EQ(client->getHit(), 0);
819
820   cerr << "SSLClientTest test completed" << endl;
821 }
822
823
824 /**
825  * Test SSL client socket session re-use
826  */
827 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
828   // Start listening on a local port
829   WriteCallbackBase writeCallback;
830   ReadCallback readCallback(&writeCallback);
831   HandshakeCallback handshakeCallback(&readCallback);
832   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
833   TestSSLServer server(&acceptCallback);
834
835   // Set up SSL client
836   EventBase eventBase;
837   auto client =
838       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
839
840   client->connect();
841   EventBaseAborter eba(&eventBase, 3000);
842   eventBase.loop();
843
844   EXPECT_EQ(client->getMiss(), 1);
845   EXPECT_EQ(client->getHit(), 9);
846
847   cerr << "SSLClientTestReuse test completed" << endl;
848 }
849
850 /**
851  * Test SSL client socket timeout
852  */
853 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
854   // Start listening on a local port
855   EmptyReadCallback readCallback;
856   HandshakeCallback handshakeCallback(&readCallback,
857                                       HandshakeCallback::EXPECT_ERROR);
858   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
859   TestSSLServer server(&acceptCallback);
860
861   // Set up SSL client
862   EventBase eventBase;
863   auto client =
864       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
865   client->connect(true /* write before connect completes */);
866   EventBaseAborter eba(&eventBase, 3000);
867   eventBase.loop();
868
869   usleep(100000);
870   // This is checking that the connectError callback precedes any queued
871   // writeError callbacks.  This matches AsyncSocket's behavior
872   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
873   EXPECT_EQ(client->getErrors(), 1);
874   EXPECT_EQ(client->getMiss(), 0);
875   EXPECT_EQ(client->getHit(), 0);
876
877   cerr << "SSLClientTimeoutTest test completed" << endl;
878 }
879
880 // This is a FB-only extension, and the tests will fail without it
881 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
882 /**
883  * Test SSL server async cache
884  */
885 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
886   // Start listening on a local port
887   WriteCallbackBase writeCallback;
888   ReadCallback readCallback(&writeCallback);
889   HandshakeCallback handshakeCallback(&readCallback);
890   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
891   TestSSLAsyncCacheServer server(&acceptCallback);
892
893   // Set up SSL client
894   EventBase eventBase;
895   auto client =
896       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
897
898   client->connect();
899   EventBaseAborter eba(&eventBase, 3000);
900   eventBase.loop();
901
902   EXPECT_EQ(server.getAsyncCallbacks(), 18);
903   EXPECT_EQ(server.getAsyncLookups(), 9);
904   EXPECT_EQ(client->getMiss(), 10);
905   EXPECT_EQ(client->getHit(), 0);
906
907   cerr << "SSLServerAsyncCacheTest test completed" << endl;
908 }
909
910
911 /**
912  * Test SSL server accept timeout with cache path
913  */
914 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
915   // Start listening on a local port
916   WriteCallbackBase writeCallback;
917   ReadCallback readCallback(&writeCallback);
918   HandshakeCallback handshakeCallback(&readCallback);
919   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
920   TestSSLAsyncCacheServer server(&acceptCallback);
921
922   // Set up SSL client
923   EventBase eventBase;
924   // only do a TCP connect
925   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
926   sock->connect(nullptr, server.getAddress());
927
928   EmptyReadCallback clientReadCallback;
929   clientReadCallback.tcpSocket_ = sock;
930   sock->setReadCB(&clientReadCallback);
931
932   EventBaseAborter eba(&eventBase, 3000);
933   eventBase.loop();
934
935   EXPECT_EQ(readCallback.state, STATE_WAITING);
936
937   cerr << "SSLServerTimeoutTest test completed" << endl;
938 }
939
940 /**
941  * Test SSL server accept timeout with cache path
942  */
943 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
944   // Start listening on a local port
945   WriteCallbackBase writeCallback;
946   ReadCallback readCallback(&writeCallback);
947   HandshakeCallback handshakeCallback(&readCallback);
948   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
949   TestSSLAsyncCacheServer server(&acceptCallback);
950
951   // Set up SSL client
952   EventBase eventBase;
953   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
954
955   client->connect();
956   EventBaseAborter eba(&eventBase, 3000);
957   eventBase.loop();
958
959   EXPECT_EQ(server.getAsyncCallbacks(), 1);
960   EXPECT_EQ(server.getAsyncLookups(), 1);
961   EXPECT_EQ(client->getErrors(), 1);
962   EXPECT_EQ(client->getMiss(), 1);
963   EXPECT_EQ(client->getHit(), 0);
964
965   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
966 }
967
968 /**
969  * Test SSL server accept timeout with cache path
970  */
971 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
972   // Start listening on a local port
973   WriteCallbackBase writeCallback;
974   ReadCallback readCallback(&writeCallback);
975   HandshakeCallback handshakeCallback(&readCallback,
976                                       HandshakeCallback::EXPECT_ERROR);
977   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
978   TestSSLAsyncCacheServer server(&acceptCallback, 500);
979
980   // Set up SSL client
981   EventBase eventBase;
982   auto client =
983       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
984
985   client->connect();
986   EventBaseAborter eba(&eventBase, 3000);
987   eventBase.loop();
988
989   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
990       handshakeCallback.closeSocket();});
991   // give time for the cache lookup to come back and find it closed
992   handshakeCallback.waitForHandshake();
993
994   EXPECT_EQ(server.getAsyncCallbacks(), 1);
995   EXPECT_EQ(server.getAsyncLookups(), 1);
996   EXPECT_EQ(client->getErrors(), 1);
997   EXPECT_EQ(client->getMiss(), 1);
998   EXPECT_EQ(client->getHit(), 0);
999
1000   cerr << "SSLServerCacheCloseTest test completed" << endl;
1001 }
1002 #endif
1003
1004 /**
1005  * Verify Client Ciphers obtained using SSL MSG Callback.
1006  */
1007 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1008   EventBase eventBase;
1009   auto clientCtx = std::make_shared<SSLContext>();
1010   auto serverCtx = std::make_shared<SSLContext>();
1011   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1012   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1013   serverCtx->loadPrivateKey(testKey);
1014   serverCtx->loadCertificate(testCert);
1015   serverCtx->loadTrustedCertificates(testCA);
1016   serverCtx->loadClientCAList(testCA);
1017
1018   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1019   clientCtx->ciphers("AES256-SHA:RC4-MD5");
1020   clientCtx->loadPrivateKey(testKey);
1021   clientCtx->loadCertificate(testCert);
1022   clientCtx->loadTrustedCertificates(testCA);
1023
1024   int fds[2];
1025   getfds(fds);
1026
1027   AsyncSSLSocket::UniquePtr clientSock(
1028       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1029   AsyncSSLSocket::UniquePtr serverSock(
1030       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1031
1032   SSLHandshakeClient client(std::move(clientSock), true, true);
1033   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1034
1035   eventBase.loop();
1036
1037   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:RC4-MD5:00ff");
1038   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1039   EXPECT_TRUE(client.handshakeVerify_);
1040   EXPECT_TRUE(client.handshakeSuccess_);
1041   EXPECT_TRUE(!client.handshakeError_);
1042   EXPECT_TRUE(server.handshakeVerify_);
1043   EXPECT_TRUE(server.handshakeSuccess_);
1044   EXPECT_TRUE(!server.handshakeError_);
1045 }
1046
1047 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1048   EventBase eventBase;
1049   auto ctx = std::make_shared<SSLContext>();
1050
1051   int fds[2];
1052   getfds(fds);
1053
1054   int bufLen = 42;
1055   uint8_t majorVersion = 18;
1056   uint8_t minorVersion = 25;
1057
1058   // Create callback buf
1059   auto buf = IOBuf::create(bufLen);
1060   buf->append(bufLen);
1061   folly::io::RWPrivateCursor cursor(buf.get());
1062   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1063   cursor.write<uint16_t>(0);
1064   cursor.write<uint8_t>(38);
1065   cursor.write<uint8_t>(majorVersion);
1066   cursor.write<uint8_t>(minorVersion);
1067   cursor.skip(32);
1068   cursor.write<uint32_t>(0);
1069
1070   SSL* ssl = ctx->createSSL();
1071   SCOPE_EXIT { SSL_free(ssl); };
1072   AsyncSSLSocket::UniquePtr sock(
1073       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1074   sock->enableClientHelloParsing();
1075
1076   // Test client hello parsing in one packet
1077   AsyncSSLSocket::clientHelloParsingCallback(
1078       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1079   buf.reset();
1080
1081   auto parsedClientHello = sock->getClientHelloInfo();
1082   EXPECT_TRUE(parsedClientHello != nullptr);
1083   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1084   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1085 }
1086
1087 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1088   EventBase eventBase;
1089   auto ctx = std::make_shared<SSLContext>();
1090
1091   int fds[2];
1092   getfds(fds);
1093
1094   int bufLen = 42;
1095   uint8_t majorVersion = 18;
1096   uint8_t minorVersion = 25;
1097
1098   // Create callback buf
1099   auto buf = IOBuf::create(bufLen);
1100   buf->append(bufLen);
1101   folly::io::RWPrivateCursor cursor(buf.get());
1102   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1103   cursor.write<uint16_t>(0);
1104   cursor.write<uint8_t>(38);
1105   cursor.write<uint8_t>(majorVersion);
1106   cursor.write<uint8_t>(minorVersion);
1107   cursor.skip(32);
1108   cursor.write<uint32_t>(0);
1109
1110   SSL* ssl = ctx->createSSL();
1111   SCOPE_EXIT { SSL_free(ssl); };
1112   AsyncSSLSocket::UniquePtr sock(
1113       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1114   sock->enableClientHelloParsing();
1115
1116   // Test parsing with two packets with first packet size < 3
1117   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1118   AsyncSSLSocket::clientHelloParsingCallback(
1119       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1120       ssl, sock.get());
1121   bufCopy.reset();
1122   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1123   AsyncSSLSocket::clientHelloParsingCallback(
1124       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1125       ssl, sock.get());
1126   bufCopy.reset();
1127
1128   auto parsedClientHello = sock->getClientHelloInfo();
1129   EXPECT_TRUE(parsedClientHello != nullptr);
1130   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1131   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1132 }
1133
1134 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1135   EventBase eventBase;
1136   auto ctx = std::make_shared<SSLContext>();
1137
1138   int fds[2];
1139   getfds(fds);
1140
1141   int bufLen = 42;
1142   uint8_t majorVersion = 18;
1143   uint8_t minorVersion = 25;
1144
1145   // Create callback buf
1146   auto buf = IOBuf::create(bufLen);
1147   buf->append(bufLen);
1148   folly::io::RWPrivateCursor cursor(buf.get());
1149   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1150   cursor.write<uint16_t>(0);
1151   cursor.write<uint8_t>(38);
1152   cursor.write<uint8_t>(majorVersion);
1153   cursor.write<uint8_t>(minorVersion);
1154   cursor.skip(32);
1155   cursor.write<uint32_t>(0);
1156
1157   SSL* ssl = ctx->createSSL();
1158   SCOPE_EXIT { SSL_free(ssl); };
1159   AsyncSSLSocket::UniquePtr sock(
1160       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1161   sock->enableClientHelloParsing();
1162
1163   // Test parsing with multiple small packets
1164   for (uint64_t i = 0; i < buf->length(); i += 3) {
1165     auto bufCopy = folly::IOBuf::copyBuffer(
1166         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1167     AsyncSSLSocket::clientHelloParsingCallback(
1168         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1169         ssl, sock.get());
1170     bufCopy.reset();
1171   }
1172
1173   auto parsedClientHello = sock->getClientHelloInfo();
1174   EXPECT_TRUE(parsedClientHello != nullptr);
1175   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1176   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1177 }
1178
1179 /**
1180  * Verify sucessful behavior of SSL certificate validation.
1181  */
1182 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1183   EventBase eventBase;
1184   auto clientCtx = std::make_shared<SSLContext>();
1185   auto dfServerCtx = std::make_shared<SSLContext>();
1186
1187   int fds[2];
1188   getfds(fds);
1189   getctx(clientCtx, dfServerCtx);
1190
1191   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1192   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1193
1194   AsyncSSLSocket::UniquePtr clientSock(
1195     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1196   AsyncSSLSocket::UniquePtr serverSock(
1197     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1198
1199   SSLHandshakeClient client(std::move(clientSock), true, true);
1200   clientCtx->loadTrustedCertificates(testCA);
1201
1202   SSLHandshakeServer server(std::move(serverSock), true, true);
1203
1204   eventBase.loop();
1205
1206   EXPECT_TRUE(client.handshakeVerify_);
1207   EXPECT_TRUE(client.handshakeSuccess_);
1208   EXPECT_TRUE(!client.handshakeError_);
1209   EXPECT_LE(0, client.handshakeTime.count());
1210   EXPECT_TRUE(!server.handshakeVerify_);
1211   EXPECT_TRUE(server.handshakeSuccess_);
1212   EXPECT_TRUE(!server.handshakeError_);
1213   EXPECT_LE(0, server.handshakeTime.count());
1214 }
1215
1216 /**
1217  * Verify that the client's verification callback is able to fail SSL
1218  * connection establishment.
1219  */
1220 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1221   EventBase eventBase;
1222   auto clientCtx = std::make_shared<SSLContext>();
1223   auto dfServerCtx = std::make_shared<SSLContext>();
1224
1225   int fds[2];
1226   getfds(fds);
1227   getctx(clientCtx, dfServerCtx);
1228
1229   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1230   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1231
1232   AsyncSSLSocket::UniquePtr clientSock(
1233     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1234   AsyncSSLSocket::UniquePtr serverSock(
1235     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1236
1237   SSLHandshakeClient client(std::move(clientSock), true, false);
1238   clientCtx->loadTrustedCertificates(testCA);
1239
1240   SSLHandshakeServer server(std::move(serverSock), true, true);
1241
1242   eventBase.loop();
1243
1244   EXPECT_TRUE(client.handshakeVerify_);
1245   EXPECT_TRUE(!client.handshakeSuccess_);
1246   EXPECT_TRUE(client.handshakeError_);
1247   EXPECT_LE(0, client.handshakeTime.count());
1248   EXPECT_TRUE(!server.handshakeVerify_);
1249   EXPECT_TRUE(!server.handshakeSuccess_);
1250   EXPECT_TRUE(server.handshakeError_);
1251   EXPECT_LE(0, server.handshakeTime.count());
1252 }
1253
1254 /**
1255  * Verify that the options in SSLContext can be overridden in
1256  * sslConnect/Accept.i.e specifying that no validation should be performed
1257  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1258  * the validation callback.
1259  */
1260 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1261   EventBase eventBase;
1262   auto clientCtx = std::make_shared<SSLContext>();
1263   auto dfServerCtx = std::make_shared<SSLContext>();
1264
1265   int fds[2];
1266   getfds(fds);
1267   getctx(clientCtx, dfServerCtx);
1268
1269   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1270   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1271
1272   AsyncSSLSocket::UniquePtr clientSock(
1273     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1274   AsyncSSLSocket::UniquePtr serverSock(
1275     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1276
1277   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1278   clientCtx->loadTrustedCertificates(testCA);
1279
1280   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1281
1282   eventBase.loop();
1283
1284   EXPECT_TRUE(!client.handshakeVerify_);
1285   EXPECT_TRUE(client.handshakeSuccess_);
1286   EXPECT_TRUE(!client.handshakeError_);
1287   EXPECT_LE(0, client.handshakeTime.count());
1288   EXPECT_TRUE(!server.handshakeVerify_);
1289   EXPECT_TRUE(server.handshakeSuccess_);
1290   EXPECT_TRUE(!server.handshakeError_);
1291   EXPECT_LE(0, server.handshakeTime.count());
1292 }
1293
1294 /**
1295  * Verify that the options in SSLContext can be overridden in
1296  * sslConnect/Accept. Enable verification even if context says otherwise.
1297  * Test requireClientCert with client cert
1298  */
1299 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1300   EventBase eventBase;
1301   auto clientCtx = std::make_shared<SSLContext>();
1302   auto serverCtx = std::make_shared<SSLContext>();
1303   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1304   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1305   serverCtx->loadPrivateKey(testKey);
1306   serverCtx->loadCertificate(testCert);
1307   serverCtx->loadTrustedCertificates(testCA);
1308   serverCtx->loadClientCAList(testCA);
1309
1310   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1311   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1312   clientCtx->loadPrivateKey(testKey);
1313   clientCtx->loadCertificate(testCert);
1314   clientCtx->loadTrustedCertificates(testCA);
1315
1316   int fds[2];
1317   getfds(fds);
1318
1319   AsyncSSLSocket::UniquePtr clientSock(
1320       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1321   AsyncSSLSocket::UniquePtr serverSock(
1322       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1323
1324   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1325   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1326
1327   eventBase.loop();
1328
1329   EXPECT_TRUE(client.handshakeVerify_);
1330   EXPECT_TRUE(client.handshakeSuccess_);
1331   EXPECT_FALSE(client.handshakeError_);
1332   EXPECT_LE(0, client.handshakeTime.count());
1333   EXPECT_TRUE(server.handshakeVerify_);
1334   EXPECT_TRUE(server.handshakeSuccess_);
1335   EXPECT_FALSE(server.handshakeError_);
1336   EXPECT_LE(0, server.handshakeTime.count());
1337 }
1338
1339 /**
1340  * Verify that the client's verification callback is able to override
1341  * the preverification failure and allow a successful connection.
1342  */
1343 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1344   EventBase eventBase;
1345   auto clientCtx = std::make_shared<SSLContext>();
1346   auto dfServerCtx = std::make_shared<SSLContext>();
1347
1348   int fds[2];
1349   getfds(fds);
1350   getctx(clientCtx, dfServerCtx);
1351
1352   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1353   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1354
1355   AsyncSSLSocket::UniquePtr clientSock(
1356     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1357   AsyncSSLSocket::UniquePtr serverSock(
1358     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1359
1360   SSLHandshakeClient client(std::move(clientSock), false, true);
1361   SSLHandshakeServer server(std::move(serverSock), true, true);
1362
1363   eventBase.loop();
1364
1365   EXPECT_TRUE(client.handshakeVerify_);
1366   EXPECT_TRUE(client.handshakeSuccess_);
1367   EXPECT_TRUE(!client.handshakeError_);
1368   EXPECT_LE(0, client.handshakeTime.count());
1369   EXPECT_TRUE(!server.handshakeVerify_);
1370   EXPECT_TRUE(server.handshakeSuccess_);
1371   EXPECT_TRUE(!server.handshakeError_);
1372   EXPECT_LE(0, server.handshakeTime.count());
1373 }
1374
1375 /**
1376  * Verify that specifying that no validation should be performed allows an
1377  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1378  * callback.
1379  */
1380 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1381   EventBase eventBase;
1382   auto clientCtx = std::make_shared<SSLContext>();
1383   auto dfServerCtx = std::make_shared<SSLContext>();
1384
1385   int fds[2];
1386   getfds(fds);
1387   getctx(clientCtx, dfServerCtx);
1388
1389   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1390   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1391
1392   AsyncSSLSocket::UniquePtr clientSock(
1393     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1394   AsyncSSLSocket::UniquePtr serverSock(
1395     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1396
1397   SSLHandshakeClient client(std::move(clientSock), false, false);
1398   SSLHandshakeServer server(std::move(serverSock), false, false);
1399
1400   eventBase.loop();
1401
1402   EXPECT_TRUE(!client.handshakeVerify_);
1403   EXPECT_TRUE(client.handshakeSuccess_);
1404   EXPECT_TRUE(!client.handshakeError_);
1405   EXPECT_LE(0, client.handshakeTime.count());
1406   EXPECT_TRUE(!server.handshakeVerify_);
1407   EXPECT_TRUE(server.handshakeSuccess_);
1408   EXPECT_TRUE(!server.handshakeError_);
1409   EXPECT_LE(0, server.handshakeTime.count());
1410 }
1411
1412 /**
1413  * Test requireClientCert with client cert
1414  */
1415 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1416   EventBase eventBase;
1417   auto clientCtx = std::make_shared<SSLContext>();
1418   auto serverCtx = std::make_shared<SSLContext>();
1419   serverCtx->setVerificationOption(
1420       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1421   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1422   serverCtx->loadPrivateKey(testKey);
1423   serverCtx->loadCertificate(testCert);
1424   serverCtx->loadTrustedCertificates(testCA);
1425   serverCtx->loadClientCAList(testCA);
1426
1427   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1428   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1429   clientCtx->loadPrivateKey(testKey);
1430   clientCtx->loadCertificate(testCert);
1431   clientCtx->loadTrustedCertificates(testCA);
1432
1433   int fds[2];
1434   getfds(fds);
1435
1436   AsyncSSLSocket::UniquePtr clientSock(
1437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1438   AsyncSSLSocket::UniquePtr serverSock(
1439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1440
1441   SSLHandshakeClient client(std::move(clientSock), true, true);
1442   SSLHandshakeServer server(std::move(serverSock), true, true);
1443
1444   eventBase.loop();
1445
1446   EXPECT_TRUE(client.handshakeVerify_);
1447   EXPECT_TRUE(client.handshakeSuccess_);
1448   EXPECT_FALSE(client.handshakeError_);
1449   EXPECT_LE(0, client.handshakeTime.count());
1450   EXPECT_TRUE(server.handshakeVerify_);
1451   EXPECT_TRUE(server.handshakeSuccess_);
1452   EXPECT_FALSE(server.handshakeError_);
1453   EXPECT_LE(0, server.handshakeTime.count());
1454 }
1455
1456
1457 /**
1458  * Test requireClientCert with no client cert
1459  */
1460 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1461   EventBase eventBase;
1462   auto clientCtx = std::make_shared<SSLContext>();
1463   auto serverCtx = std::make_shared<SSLContext>();
1464   serverCtx->setVerificationOption(
1465       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1466   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1467   serverCtx->loadPrivateKey(testKey);
1468   serverCtx->loadCertificate(testCert);
1469   serverCtx->loadTrustedCertificates(testCA);
1470   serverCtx->loadClientCAList(testCA);
1471   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1472   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1473
1474   int fds[2];
1475   getfds(fds);
1476
1477   AsyncSSLSocket::UniquePtr clientSock(
1478       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1479   AsyncSSLSocket::UniquePtr serverSock(
1480       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1481
1482   SSLHandshakeClient client(std::move(clientSock), false, false);
1483   SSLHandshakeServer server(std::move(serverSock), false, false);
1484
1485   eventBase.loop();
1486
1487   EXPECT_FALSE(server.handshakeVerify_);
1488   EXPECT_FALSE(server.handshakeSuccess_);
1489   EXPECT_TRUE(server.handshakeError_);
1490   EXPECT_LE(0, client.handshakeTime.count());
1491   EXPECT_LE(0, server.handshakeTime.count());
1492 }
1493
1494 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1495   auto cert = getFileAsBuf(testCert);
1496   auto key = getFileAsBuf(testKey);
1497
1498   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1499   BIO_write(certBio.get(), cert.data(), cert.size());
1500   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1501   BIO_write(keyBio.get(), key.data(), key.size());
1502
1503   // Create SSL structs from buffers to get properties
1504   ssl::X509UniquePtr certStruct(
1505       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1506   ssl::EvpPkeyUniquePtr keyStruct(
1507       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1508   certBio = nullptr;
1509   keyBio = nullptr;
1510
1511   auto origCommonName = getCommonName(certStruct.get());
1512   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1513   certStruct = nullptr;
1514   keyStruct = nullptr;
1515
1516   auto ctx = std::make_shared<SSLContext>();
1517   ctx->loadPrivateKeyFromBufferPEM(key);
1518   ctx->loadCertificateFromBufferPEM(cert);
1519   ctx->loadTrustedCertificates(testCA);
1520
1521   ssl::SSLUniquePtr ssl(ctx->createSSL());
1522
1523   auto newCert = SSL_get_certificate(ssl.get());
1524   auto newKey = SSL_get_privatekey(ssl.get());
1525
1526   // Get properties from SSL struct
1527   auto newCommonName = getCommonName(newCert);
1528   auto newKeySize = EVP_PKEY_bits(newKey);
1529
1530   // Check that the key and cert have the expected properties
1531   EXPECT_EQ(origCommonName, newCommonName);
1532   EXPECT_EQ(origKeySize, newKeySize);
1533 }
1534
1535 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1536   EventBase eb;
1537
1538   // Set up SSL context.
1539   auto sslContext = std::make_shared<SSLContext>();
1540   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1541
1542   // create SSL socket
1543   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1544
1545   EXPECT_EQ(1500, socket->getMinWriteSize());
1546
1547   socket->setMinWriteSize(0);
1548   EXPECT_EQ(0, socket->getMinWriteSize());
1549   socket->setMinWriteSize(50000);
1550   EXPECT_EQ(50000, socket->getMinWriteSize());
1551 }
1552
1553 class ReadCallbackTerminator : public ReadCallback {
1554  public:
1555   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1556       : ReadCallback(wcb)
1557       , base_(base) {}
1558
1559   // Do not write data back, terminate the loop.
1560   void readDataAvailable(size_t len) noexcept override {
1561     std::cerr << "readDataAvailable, len " << len << std::endl;
1562
1563     currentBuffer.length = len;
1564
1565     buffers.push_back(currentBuffer);
1566     currentBuffer.reset();
1567     state = STATE_SUCCEEDED;
1568
1569     socket_->setReadCB(nullptr);
1570     base_->terminateLoopSoon();
1571   }
1572  private:
1573   EventBase* base_;
1574 };
1575
1576
1577 /**
1578  * Test a full unencrypted codepath
1579  */
1580 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1581   EventBase base;
1582
1583   auto clientCtx = std::make_shared<folly::SSLContext>();
1584   auto serverCtx = std::make_shared<folly::SSLContext>();
1585   int fds[2];
1586   getfds(fds);
1587   getctx(clientCtx, serverCtx);
1588   auto client = AsyncSSLSocket::newSocket(
1589                   clientCtx, &base, fds[0], false, true);
1590   auto server = AsyncSSLSocket::newSocket(
1591                   serverCtx, &base, fds[1], true, true);
1592
1593   ReadCallbackTerminator readCallback(&base, nullptr);
1594   server->setReadCB(&readCallback);
1595   readCallback.setSocket(server);
1596
1597   uint8_t buf[128];
1598   memset(buf, 'a', sizeof(buf));
1599   client->write(nullptr, buf, sizeof(buf));
1600
1601   // Check that bytes are unencrypted
1602   char c;
1603   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1604   EXPECT_EQ('a', c);
1605
1606   EventBaseAborter eba(&base, 3000);
1607   base.loop();
1608
1609   EXPECT_EQ(1, readCallback.buffers.size());
1610   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1611
1612   server->setReadCB(&readCallback);
1613
1614   // Unencrypted
1615   server->sslAccept(nullptr);
1616   client->sslConn(nullptr);
1617
1618   // Do NOT wait for handshake, writing should be queued and happen after
1619
1620   client->write(nullptr, buf, sizeof(buf));
1621
1622   // Check that bytes are *not* unencrypted
1623   char c2;
1624   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1625   EXPECT_NE('a', c2);
1626
1627
1628   base.loop();
1629
1630   EXPECT_EQ(2, readCallback.buffers.size());
1631   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1632 }
1633
1634 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1635   // Start listening on a local port
1636   WriteCallbackBase writeCallback;
1637   WriteErrorCallback readCallback(&writeCallback);
1638   HandshakeCallback handshakeCallback(&readCallback,
1639                                       HandshakeCallback::EXPECT_ERROR);
1640   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1641   TestSSLServer server(&acceptCallback);
1642
1643   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1644   socket->open();
1645   uint8_t buf[3] = {0x16, 0x03, 0x01};
1646   socket->write(buf, sizeof(buf));
1647   socket->closeWithReset();
1648
1649   handshakeCallback.waitForHandshake();
1650   EXPECT_NE(
1651       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1652   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1653 }
1654
1655 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1656   // Start listening on a local port
1657   WriteCallbackBase writeCallback;
1658   WriteErrorCallback readCallback(&writeCallback);
1659   HandshakeCallback handshakeCallback(&readCallback,
1660                                       HandshakeCallback::EXPECT_ERROR);
1661   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1662   TestSSLServer server(&acceptCallback);
1663
1664   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1665   socket->open();
1666   uint8_t buf[3] = {0x16, 0x03, 0x01};
1667   socket->write(buf, sizeof(buf));
1668   socket->close();
1669
1670   handshakeCallback.waitForHandshake();
1671   EXPECT_NE(
1672       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1673   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1674 }
1675
1676 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1677   // Start listening on a local port
1678   WriteCallbackBase writeCallback;
1679   WriteErrorCallback readCallback(&writeCallback);
1680   HandshakeCallback handshakeCallback(&readCallback,
1681                                       HandshakeCallback::EXPECT_ERROR);
1682   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1683   TestSSLServer server(&acceptCallback);
1684
1685   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1686   socket->open();
1687   uint8_t buf[256] = {0x16, 0x03};
1688   memset(buf + 2, 'a', sizeof(buf) - 2);
1689   socket->write(buf, sizeof(buf));
1690   socket->close();
1691
1692   handshakeCallback.waitForHandshake();
1693   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1694             std::string::npos);
1695   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1696             std::string::npos);
1697 }
1698
1699 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1700   using folly::ssl::OpenSSLUtils;
1701   EXPECT_EQ(
1702       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1703   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1704   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1705   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1706   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1707 }
1708
1709 #if FOLLY_ALLOW_TFO
1710
1711 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1712  public:
1713   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1714
1715   explicit MockAsyncTFOSSLSocket(
1716       std::shared_ptr<folly::SSLContext> sslCtx,
1717       EventBase* evb)
1718       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1719
1720   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1721 };
1722
1723 /**
1724  * Test connecting to, writing to, reading from, and closing the
1725  * connection to the SSL server with TFO.
1726  */
1727 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1728   // Start listening on a local port
1729   WriteCallbackBase writeCallback;
1730   ReadCallback readCallback(&writeCallback);
1731   HandshakeCallback handshakeCallback(&readCallback);
1732   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1733   TestSSLServer server(&acceptCallback, true);
1734
1735   // Set up SSL context.
1736   auto sslContext = std::make_shared<SSLContext>();
1737
1738   // connect
1739   auto socket =
1740       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1741   socket->enableTFO();
1742   socket->open();
1743
1744   // write()
1745   std::array<uint8_t, 128> buf;
1746   memset(buf.data(), 'a', buf.size());
1747   socket->write(buf.data(), buf.size());
1748
1749   // read()
1750   std::array<uint8_t, 128> readbuf;
1751   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1752   EXPECT_EQ(bytesRead, 128);
1753   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1754
1755   // close()
1756   socket->close();
1757 }
1758
1759 /**
1760  * Test connecting to, writing to, reading from, and closing the
1761  * connection to the SSL server with TFO.
1762  */
1763 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1764   // Start listening on a local port
1765   WriteCallbackBase writeCallback;
1766   ReadCallback readCallback(&writeCallback);
1767   HandshakeCallback handshakeCallback(&readCallback);
1768   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1769   TestSSLServer server(&acceptCallback, false);
1770
1771   // Set up SSL context.
1772   auto sslContext = std::make_shared<SSLContext>();
1773
1774   // connect
1775   auto socket =
1776       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1777   socket->enableTFO();
1778   socket->open();
1779
1780   // write()
1781   std::array<uint8_t, 128> buf;
1782   memset(buf.data(), 'a', buf.size());
1783   socket->write(buf.data(), buf.size());
1784
1785   // read()
1786   std::array<uint8_t, 128> readbuf;
1787   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1788   EXPECT_EQ(bytesRead, 128);
1789   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1790
1791   // close()
1792   socket->close();
1793 }
1794
1795 class ConnCallback : public AsyncSocket::ConnectCallback {
1796  public:
1797   virtual void connectSuccess() noexcept override {
1798     state = State::SUCCESS;
1799   }
1800
1801   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1802     state = State::ERROR;
1803     error = ex.what();
1804   }
1805
1806   enum class State { WAITING, SUCCESS, ERROR };
1807
1808   State state{State::WAITING};
1809   std::string error;
1810 };
1811
1812 template <class Cardinality>
1813 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1814     EventBase* evb,
1815     const SocketAddress& address,
1816     Cardinality cardinality) {
1817   // Set up SSL context.
1818   auto sslContext = std::make_shared<SSLContext>();
1819
1820   // connect
1821   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1822       new MockAsyncTFOSSLSocket(sslContext, evb));
1823   socket->enableTFO();
1824
1825   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1826       .Times(cardinality)
1827       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1828         sockaddr_storage addr;
1829         auto len = address.getAddress(&addr);
1830         return connect(fd, (const struct sockaddr*)&addr, len);
1831       }));
1832   return socket;
1833 }
1834
1835 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1836   // Start listening on a local port
1837   WriteCallbackBase writeCallback;
1838   ReadCallback readCallback(&writeCallback);
1839   HandshakeCallback handshakeCallback(&readCallback);
1840   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1841   TestSSLServer server(&acceptCallback, true);
1842
1843   EventBase evb;
1844
1845   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1846   ConnCallback ccb;
1847   socket->connect(&ccb, server.getAddress(), 30);
1848
1849   evb.loop();
1850   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1851
1852   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1853   evb.loop();
1854
1855   BlockingSocket sock(std::move(socket));
1856   // write()
1857   std::array<uint8_t, 128> buf;
1858   memset(buf.data(), 'a', buf.size());
1859   sock.write(buf.data(), buf.size());
1860
1861   // read()
1862   std::array<uint8_t, 128> readbuf;
1863   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1864   EXPECT_EQ(bytesRead, 128);
1865   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1866
1867   // close()
1868   sock.close();
1869 }
1870
1871 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1872   // Start listening on a local port
1873   ConnectTimeoutCallback acceptCallback;
1874   TestSSLServer server(&acceptCallback, true);
1875
1876   // Set up SSL context.
1877   auto sslContext = std::make_shared<SSLContext>();
1878
1879   // connect
1880   auto socket =
1881       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1882   socket->enableTFO();
1883   EXPECT_THROW(
1884       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1885 }
1886
1887 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1888   // Start listening on a local port
1889   ConnectTimeoutCallback acceptCallback;
1890   TestSSLServer server(&acceptCallback, true);
1891
1892   EventBase evb;
1893
1894   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1895   ConnCallback ccb;
1896   // Set a short timeout
1897   socket->connect(&ccb, server.getAddress(), 1);
1898
1899   evb.loop();
1900   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1901 }
1902
1903 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1904   // Start listening on a local port
1905   EmptyReadCallback readCallback;
1906   HandshakeCallback handshakeCallback(
1907       &readCallback, HandshakeCallback::EXPECT_ERROR);
1908   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1909   TestSSLServer server(&acceptCallback, true);
1910
1911   EventBase evb;
1912
1913   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1914   ConnCallback ccb;
1915   socket->connect(&ccb, server.getAddress(), 100);
1916
1917   evb.loop();
1918   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1919   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1920 }
1921
1922 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1923   // Start listening on a local port
1924   EventBase evb;
1925
1926   // Hopefully nothing is listening on this address
1927   SocketAddress addr("127.0.0.1", 65535);
1928   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1929   ConnCallback ccb;
1930   socket->connect(&ccb, addr, 100);
1931
1932   evb.loop();
1933   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1934   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1935 }
1936
1937 #endif
1938
1939 } // namespace
1940
1941 #ifdef SIGPIPE
1942 ///////////////////////////////////////////////////////////////////////////
1943 // init_unit_test_suite
1944 ///////////////////////////////////////////////////////////////////////////
1945 namespace {
1946 struct Initializer {
1947   Initializer() {
1948     signal(SIGPIPE, SIG_IGN);
1949   }
1950 };
1951 Initializer initializer;
1952 } // anonymous
1953 #endif