Switch uses of networking headers to <folly/portability/Sockets.h>
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <folly/io/async/test/BlockingSocket.h>
28
29 #include <fstream>
30 #include <gtest/gtest.h>
31 #include <iostream>
32 #include <list>
33 #include <set>
34 #include <fcntl.h>
35 #include <openssl/bio.h>
36 #include <sys/types.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
59     : ctx_(new folly::SSLContext),
60       acb_(acb),
61       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79   (void)ret;
80
81   std::cerr << "Accepting connections on " << address_ << std::endl;
82 }
83
84 void getfds(int fds[2]) {
85   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
86     FAIL() << "failed to create socketpair: " << strerror(errno);
87   }
88   for (int idx = 0; idx < 2; ++idx) {
89     int flags = fcntl(fds[idx], F_GETFL, 0);
90     if (flags == -1) {
91       FAIL() << "failed to get flags for socket " << idx << ": "
92              << strerror(errno);
93     }
94     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
95       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
96              << strerror(errno);
97     }
98   }
99 }
100
101 void getctx(
102   std::shared_ptr<folly::SSLContext> clientCtx,
103   std::shared_ptr<folly::SSLContext> serverCtx) {
104   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
105
106   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107   serverCtx->loadCertificate(
108       testCert);
109   serverCtx->loadPrivateKey(
110       testKey);
111 }
112
113 void sslsocketpair(
114   EventBase* eventBase,
115   AsyncSSLSocket::UniquePtr* clientSock,
116   AsyncSSLSocket::UniquePtr* serverSock) {
117   auto clientCtx = std::make_shared<folly::SSLContext>();
118   auto serverCtx = std::make_shared<folly::SSLContext>();
119   int fds[2];
120   getfds(fds);
121   getctx(clientCtx, serverCtx);
122   clientSock->reset(new AsyncSSLSocket(
123                       clientCtx, eventBase, fds[0], false));
124   serverSock->reset(new AsyncSSLSocket(
125                       serverCtx, eventBase, fds[1], true));
126
127   // (*clientSock)->setSendTimeout(100);
128   // (*serverSock)->setSendTimeout(100);
129 }
130
131 // client protocol filters
132 bool clientProtoFilterPickPony(unsigned char** client,
133   unsigned int* client_len, const unsigned char*, unsigned int ) {
134   //the protocol string in length prefixed byte string. the
135   //length byte is not included in the length
136   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
137   *client = p;
138   *client_len = 7;
139   return true;
140 }
141
142 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
143   const unsigned char*, unsigned int) {
144   return false;
145 }
146
147 std::string getFileAsBuf(const char* fileName) {
148   std::string buffer;
149   folly::readFile(fileName, buffer);
150   return buffer;
151 }
152
153 std::string getCommonName(X509* cert) {
154   X509_NAME* subject = X509_get_subject_name(cert);
155   std::string cn;
156   cn.resize(ub_common_name);
157   X509_NAME_get_text_by_NID(
158       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
159   return cn;
160 }
161
162 /**
163  * Test connecting to, writing to, reading from, and closing the
164  * connection to the SSL server.
165  */
166 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
167   // Start listening on a local port
168   WriteCallbackBase writeCallback;
169   ReadCallback readCallback(&writeCallback);
170   HandshakeCallback handshakeCallback(&readCallback);
171   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
172   TestSSLServer server(&acceptCallback);
173
174   // Set up SSL context.
175   std::shared_ptr<SSLContext> sslContext(new SSLContext());
176   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
177   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
178   //sslContext->authenticate(true, false);
179
180   // connect
181   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
182                                                  sslContext);
183   socket->open();
184
185   // write()
186   uint8_t buf[128];
187   memset(buf, 'a', sizeof(buf));
188   socket->write(buf, sizeof(buf));
189
190   // read()
191   uint8_t readbuf[128];
192   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
193   EXPECT_EQ(bytesRead, 128);
194   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
195
196   // close()
197   socket->close();
198
199   cerr << "ConnectWriteReadClose test completed" << endl;
200 }
201
202 /**
203  * Test reading after server close.
204  */
205 TEST(AsyncSSLSocketTest, ReadAfterClose) {
206   // Start listening on a local port
207   WriteCallbackBase writeCallback;
208   ReadEOFCallback readCallback(&writeCallback);
209   HandshakeCallback handshakeCallback(&readCallback);
210   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
211   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
212
213   // Set up SSL context.
214   auto sslContext = std::make_shared<SSLContext>();
215   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
216
217   auto socket =
218       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
219   socket->open();
220
221   // This should trigger an EOF on the client.
222   auto evb = handshakeCallback.getSocket()->getEventBase();
223   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
224   std::array<uint8_t, 128> readbuf;
225   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
226   EXPECT_EQ(0, bytesRead);
227 }
228
229 /**
230  * Test bad renegotiation
231  */
232 TEST(AsyncSSLSocketTest, Renegotiate) {
233   EventBase eventBase;
234   auto clientCtx = std::make_shared<SSLContext>();
235   auto dfServerCtx = std::make_shared<SSLContext>();
236   std::array<int, 2> fds;
237   getfds(fds.data());
238   getctx(clientCtx, dfServerCtx);
239
240   AsyncSSLSocket::UniquePtr clientSock(
241       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
242   AsyncSSLSocket::UniquePtr serverSock(
243       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
244   SSLHandshakeClient client(std::move(clientSock), true, true);
245   RenegotiatingServer server(std::move(serverSock));
246
247   while (!client.handshakeSuccess_ && !client.handshakeError_) {
248     eventBase.loopOnce();
249   }
250
251   ASSERT_TRUE(client.handshakeSuccess_);
252
253   auto sslSock = std::move(client).moveSocket();
254   sslSock->detachEventBase();
255   // This is nasty, however we don't want to add support for
256   // renegotiation in AsyncSSLSocket.
257   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
258
259   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
260
261   std::thread t([&]() { eventBase.loopForever(); });
262
263   // Trigger the renegotiation.
264   std::array<uint8_t, 128> buf;
265   memset(buf.data(), 'a', buf.size());
266   try {
267     socket->write(buf.data(), buf.size());
268   } catch (AsyncSocketException& e) {
269     LOG(INFO) << "client got error " << e.what();
270   }
271   eventBase.terminateLoopSoon();
272   t.join();
273
274   eventBase.loop();
275   ASSERT_TRUE(server.renegotiationError_);
276 }
277
278 /**
279  * Negative test for handshakeError().
280  */
281 TEST(AsyncSSLSocketTest, HandshakeError) {
282   // Start listening on a local port
283   WriteCallbackBase writeCallback;
284   WriteErrorCallback readCallback(&writeCallback);
285   HandshakeCallback handshakeCallback(&readCallback);
286   HandshakeErrorCallback acceptCallback(&handshakeCallback);
287   TestSSLServer server(&acceptCallback);
288
289   // Set up SSL context.
290   std::shared_ptr<SSLContext> sslContext(new SSLContext());
291   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
292
293   // connect
294   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
295                                                  sslContext);
296   // read()
297   bool ex = false;
298   try {
299     socket->open();
300
301     uint8_t readbuf[128];
302     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
303     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
304   } catch (AsyncSocketException &e) {
305     ex = true;
306   }
307   EXPECT_TRUE(ex);
308
309   // close()
310   socket->close();
311   cerr << "HandshakeError test completed" << endl;
312 }
313
314 /**
315  * Negative test for readError().
316  */
317 TEST(AsyncSSLSocketTest, ReadError) {
318   // Start listening on a local port
319   WriteCallbackBase writeCallback;
320   ReadErrorCallback readCallback(&writeCallback);
321   HandshakeCallback handshakeCallback(&readCallback);
322   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
323   TestSSLServer server(&acceptCallback);
324
325   // Set up SSL context.
326   std::shared_ptr<SSLContext> sslContext(new SSLContext());
327   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
328
329   // connect
330   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
331                                                  sslContext);
332   socket->open();
333
334   // write something to trigger ssl handshake
335   uint8_t buf[128];
336   memset(buf, 'a', sizeof(buf));
337   socket->write(buf, sizeof(buf));
338
339   socket->close();
340   cerr << "ReadError test completed" << endl;
341 }
342
343 /**
344  * Negative test for writeError().
345  */
346 TEST(AsyncSSLSocketTest, WriteError) {
347   // Start listening on a local port
348   WriteCallbackBase writeCallback;
349   WriteErrorCallback readCallback(&writeCallback);
350   HandshakeCallback handshakeCallback(&readCallback);
351   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
352   TestSSLServer server(&acceptCallback);
353
354   // Set up SSL context.
355   std::shared_ptr<SSLContext> sslContext(new SSLContext());
356   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
357
358   // connect
359   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
360                                                  sslContext);
361   socket->open();
362
363   // write something to trigger ssl handshake
364   uint8_t buf[128];
365   memset(buf, 'a', sizeof(buf));
366   socket->write(buf, sizeof(buf));
367
368   socket->close();
369   cerr << "WriteError test completed" << endl;
370 }
371
372 /**
373  * Test a socket with TCP_NODELAY unset.
374  */
375 TEST(AsyncSSLSocketTest, SocketWithDelay) {
376   // Start listening on a local port
377   WriteCallbackBase writeCallback;
378   ReadCallback readCallback(&writeCallback);
379   HandshakeCallback handshakeCallback(&readCallback);
380   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
381   TestSSLServer server(&acceptCallback);
382
383   // Set up SSL context.
384   std::shared_ptr<SSLContext> sslContext(new SSLContext());
385   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
386
387   // connect
388   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
389                                                  sslContext);
390   socket->open();
391
392   // write()
393   uint8_t buf[128];
394   memset(buf, 'a', sizeof(buf));
395   socket->write(buf, sizeof(buf));
396
397   // read()
398   uint8_t readbuf[128];
399   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
400   EXPECT_EQ(bytesRead, 128);
401   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
402
403   // close()
404   socket->close();
405
406   cerr << "SocketWithDelay test completed" << endl;
407 }
408
409 using NextProtocolTypePair =
410     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
411
412 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
413   // For matching protos
414  public:
415   void SetUp() override { getctx(clientCtx, serverCtx); }
416
417   void connect(bool unset = false) {
418     getfds(fds);
419
420     if (unset) {
421       // unsetting NPN for any of [client, server] is enough to make NPN not
422       // work
423       clientCtx->unsetNextProtocols();
424     }
425
426     AsyncSSLSocket::UniquePtr clientSock(
427       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
428     AsyncSSLSocket::UniquePtr serverSock(
429       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
430     client = folly::make_unique<NpnClient>(std::move(clientSock));
431     server = folly::make_unique<NpnServer>(std::move(serverSock));
432
433     eventBase.loop();
434   }
435
436   void expectProtocol(const std::string& proto) {
437     EXPECT_NE(client->nextProtoLength, 0);
438     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
439     EXPECT_EQ(
440         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
441         0);
442     string selected((const char*)client->nextProto, client->nextProtoLength);
443     EXPECT_EQ(proto, selected);
444   }
445
446   void expectNoProtocol() {
447     EXPECT_EQ(client->nextProtoLength, 0);
448     EXPECT_EQ(server->nextProtoLength, 0);
449     EXPECT_EQ(client->nextProto, nullptr);
450     EXPECT_EQ(server->nextProto, nullptr);
451   }
452
453   void expectProtocolType() {
454     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
455         GetParam().second == SSLContext::NextProtocolType::ANY) {
456       EXPECT_EQ(client->protocolType, server->protocolType);
457     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
458                GetParam().second == SSLContext::NextProtocolType::ANY) {
459       // Well not much we can say
460     } else {
461       expectProtocolType(GetParam());
462     }
463   }
464
465   void expectProtocolType(NextProtocolTypePair expected) {
466     EXPECT_EQ(client->protocolType, expected.first);
467     EXPECT_EQ(server->protocolType, expected.second);
468   }
469
470   EventBase eventBase;
471   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
472   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
473   int fds[2];
474   std::unique_ptr<NpnClient> client;
475   std::unique_ptr<NpnServer> server;
476 };
477
478 class NextProtocolNPNOnlyTest : public NextProtocolTest {
479   // For mismatching protos
480 };
481
482 class NextProtocolMismatchTest : public NextProtocolTest {
483   // For mismatching protos
484 };
485
486 TEST_P(NextProtocolTest, NpnTestOverlap) {
487   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
488   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
489                                         GetParam().second);
490
491   connect();
492
493   expectProtocol("baz");
494   expectProtocolType();
495 }
496
497 TEST_P(NextProtocolTest, NpnTestUnset) {
498   // Identical to above test, except that we want unset NPN before
499   // looping.
500   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
501   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
502                                         GetParam().second);
503
504   connect(true /* unset */);
505
506   // if alpn negotiation fails, type will appear as npn
507   expectNoProtocol();
508   EXPECT_EQ(client->protocolType, server->protocolType);
509 }
510
511 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
512   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
513   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
514                                         GetParam().second);
515
516   connect();
517
518   expectNoProtocol();
519   expectProtocolType(
520       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
521 }
522
523 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
524 // will fail on 1.0.2 before that.
525 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529
530   connect();
531
532   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
533       GetParam().second == SSLContext::NextProtocolType::ALPN) {
534     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
535     // mismatch should result in a fatal alert, but this is OpenSSL's current
536     // behavior and we want to know if it changes.
537     expectNoProtocol();
538   } else {
539     expectProtocol("blub");
540     expectProtocolType(
541         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
542   }
543 }
544
545 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
546   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
547   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
548   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
549                                         GetParam().second);
550
551   connect();
552
553   expectProtocol("ponies");
554   expectProtocolType();
555 }
556
557 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
558   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
559   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
560   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
561                                         GetParam().second);
562
563   connect();
564
565   expectProtocol("blub");
566   expectProtocolType();
567 }
568
569 TEST_P(NextProtocolTest, RandomizedNpnTest) {
570   // Probability that this test will fail is 2^-64, which could be considered
571   // as negligible.
572   const int kTries = 64;
573
574   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
575                                         GetParam().first);
576   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
577                                                   GetParam().second);
578
579   std::set<string> selectedProtocols;
580   for (int i = 0; i < kTries; ++i) {
581     connect();
582
583     EXPECT_NE(client->nextProtoLength, 0);
584     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
585     EXPECT_EQ(
586         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
587         0);
588     string selected((const char*)client->nextProto, client->nextProtoLength);
589     selectedProtocols.insert(selected);
590     expectProtocolType();
591   }
592   EXPECT_EQ(selectedProtocols.size(), 2);
593 }
594
595 INSTANTIATE_TEST_CASE_P(
596     AsyncSSLSocketTest,
597     NextProtocolTest,
598     ::testing::Values(
599         NextProtocolTypePair(
600             SSLContext::NextProtocolType::NPN,
601             SSLContext::NextProtocolType::NPN),
602 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
603         NextProtocolTypePair(
604             SSLContext::NextProtocolType::ALPN,
605             SSLContext::NextProtocolType::ALPN),
606         NextProtocolTypePair(
607             SSLContext::NextProtocolType::ALPN,
608             SSLContext::NextProtocolType::ANY),
609         NextProtocolTypePair(
610             SSLContext::NextProtocolType::ANY,
611             SSLContext::NextProtocolType::ALPN),
612 #endif
613         NextProtocolTypePair(
614             SSLContext::NextProtocolType::NPN,
615             SSLContext::NextProtocolType::ANY),
616         NextProtocolTypePair(
617             SSLContext::NextProtocolType::ANY,
618             SSLContext::NextProtocolType::ANY)));
619
620 INSTANTIATE_TEST_CASE_P(
621     AsyncSSLSocketTest,
622     NextProtocolNPNOnlyTest,
623     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
624                                            SSLContext::NextProtocolType::NPN)));
625
626 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
627 INSTANTIATE_TEST_CASE_P(
628     AsyncSSLSocketTest,
629     NextProtocolMismatchTest,
630     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
631                                            SSLContext::NextProtocolType::ALPN),
632                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
633                                            SSLContext::NextProtocolType::NPN)));
634 #endif
635
636 #ifndef OPENSSL_NO_TLSEXT
637 /**
638  * 1. Client sends TLSEXT_HOSTNAME in client hello.
639  * 2. Server found a match SSL_CTX and use this SSL_CTX to
640  *    continue the SSL handshake.
641  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
642  */
643 TEST(AsyncSSLSocketTest, SNITestMatch) {
644   EventBase eventBase;
645   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
646   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
647   // Use the same SSLContext to continue the handshake after
648   // tlsext_hostname match.
649   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
650   const std::string serverName("xyz.newdev.facebook.com");
651   int fds[2];
652   getfds(fds);
653   getctx(clientCtx, dfServerCtx);
654
655   AsyncSSLSocket::UniquePtr clientSock(
656     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
657   AsyncSSLSocket::UniquePtr serverSock(
658     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
659   SNIClient client(std::move(clientSock));
660   SNIServer server(std::move(serverSock),
661                    dfServerCtx,
662                    hskServerCtx,
663                    serverName);
664
665   eventBase.loop();
666
667   EXPECT_TRUE(client.serverNameMatch);
668   EXPECT_TRUE(server.serverNameMatch);
669 }
670
671 /**
672  * 1. Client sends TLSEXT_HOSTNAME in client hello.
673  * 2. Server cannot find a matching SSL_CTX and continue to use
674  *    the current SSL_CTX to do the handshake.
675  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
676  */
677 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
678   EventBase eventBase;
679   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
680   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
681   // Use the same SSLContext to continue the handshake after
682   // tlsext_hostname match.
683   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
684   const std::string clientRequestingServerName("foo.com");
685   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
686
687   int fds[2];
688   getfds(fds);
689   getctx(clientCtx, dfServerCtx);
690
691   AsyncSSLSocket::UniquePtr clientSock(
692     new AsyncSSLSocket(clientCtx,
693                         &eventBase,
694                         fds[0],
695                         clientRequestingServerName));
696   AsyncSSLSocket::UniquePtr serverSock(
697     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
698   SNIClient client(std::move(clientSock));
699   SNIServer server(std::move(serverSock),
700                    dfServerCtx,
701                    hskServerCtx,
702                    serverExpectedServerName);
703
704   eventBase.loop();
705
706   EXPECT_TRUE(!client.serverNameMatch);
707   EXPECT_TRUE(!server.serverNameMatch);
708 }
709 /**
710  * 1. Client sends TLSEXT_HOSTNAME in client hello.
711  * 2. We then change the serverName.
712  * 3. We expect that we get 'false' as the result for serNameMatch.
713  */
714
715 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
716    EventBase eventBase;
717   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
718   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
719   // Use the same SSLContext to continue the handshake after
720   // tlsext_hostname match.
721   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
722   const std::string serverName("xyz.newdev.facebook.com");
723   int fds[2];
724   getfds(fds);
725   getctx(clientCtx, dfServerCtx);
726
727   AsyncSSLSocket::UniquePtr clientSock(
728     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
729   //Change the server name
730   std::string newName("new.com");
731   clientSock->setServerName(newName);
732   AsyncSSLSocket::UniquePtr serverSock(
733     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
734   SNIClient client(std::move(clientSock));
735   SNIServer server(std::move(serverSock),
736                    dfServerCtx,
737                    hskServerCtx,
738                    serverName);
739
740   eventBase.loop();
741
742   EXPECT_TRUE(!client.serverNameMatch);
743 }
744
745 /**
746  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
747  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
748  */
749 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
750   EventBase eventBase;
751   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
752   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
753   // Use the same SSLContext to continue the handshake after
754   // tlsext_hostname match.
755   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
756   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
757
758   int fds[2];
759   getfds(fds);
760   getctx(clientCtx, dfServerCtx);
761
762   AsyncSSLSocket::UniquePtr clientSock(
763     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
764   AsyncSSLSocket::UniquePtr serverSock(
765     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
766   SNIClient client(std::move(clientSock));
767   SNIServer server(std::move(serverSock),
768                    dfServerCtx,
769                    hskServerCtx,
770                    serverExpectedServerName);
771
772   eventBase.loop();
773
774   EXPECT_TRUE(!client.serverNameMatch);
775   EXPECT_TRUE(!server.serverNameMatch);
776 }
777
778 #endif
779 /**
780  * Test SSL client socket
781  */
782 TEST(AsyncSSLSocketTest, SSLClientTest) {
783   // Start listening on a local port
784   WriteCallbackBase writeCallback;
785   ReadCallback readCallback(&writeCallback);
786   HandshakeCallback handshakeCallback(&readCallback);
787   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
788   TestSSLServer server(&acceptCallback);
789
790   // Set up SSL client
791   EventBase eventBase;
792   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
793
794   client->connect();
795   EventBaseAborter eba(&eventBase, 3000);
796   eventBase.loop();
797
798   EXPECT_EQ(client->getMiss(), 1);
799   EXPECT_EQ(client->getHit(), 0);
800
801   cerr << "SSLClientTest test completed" << endl;
802 }
803
804
805 /**
806  * Test SSL client socket session re-use
807  */
808 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
809   // Start listening on a local port
810   WriteCallbackBase writeCallback;
811   ReadCallback readCallback(&writeCallback);
812   HandshakeCallback handshakeCallback(&readCallback);
813   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
814   TestSSLServer server(&acceptCallback);
815
816   // Set up SSL client
817   EventBase eventBase;
818   auto client =
819       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
820
821   client->connect();
822   EventBaseAborter eba(&eventBase, 3000);
823   eventBase.loop();
824
825   EXPECT_EQ(client->getMiss(), 1);
826   EXPECT_EQ(client->getHit(), 9);
827
828   cerr << "SSLClientTestReuse test completed" << endl;
829 }
830
831 /**
832  * Test SSL client socket timeout
833  */
834 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
835   // Start listening on a local port
836   EmptyReadCallback readCallback;
837   HandshakeCallback handshakeCallback(&readCallback,
838                                       HandshakeCallback::EXPECT_ERROR);
839   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
840   TestSSLServer server(&acceptCallback);
841
842   // Set up SSL client
843   EventBase eventBase;
844   auto client =
845       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
846   client->connect(true /* write before connect completes */);
847   EventBaseAborter eba(&eventBase, 3000);
848   eventBase.loop();
849
850   usleep(100000);
851   // This is checking that the connectError callback precedes any queued
852   // writeError callbacks.  This matches AsyncSocket's behavior
853   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
854   EXPECT_EQ(client->getErrors(), 1);
855   EXPECT_EQ(client->getMiss(), 0);
856   EXPECT_EQ(client->getHit(), 0);
857
858   cerr << "SSLClientTimeoutTest test completed" << endl;
859 }
860
861
862 /**
863  * Test SSL server async cache
864  */
865 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
866   // Start listening on a local port
867   WriteCallbackBase writeCallback;
868   ReadCallback readCallback(&writeCallback);
869   HandshakeCallback handshakeCallback(&readCallback);
870   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
871   TestSSLAsyncCacheServer server(&acceptCallback);
872
873   // Set up SSL client
874   EventBase eventBase;
875   auto client =
876       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
877
878   client->connect();
879   EventBaseAborter eba(&eventBase, 3000);
880   eventBase.loop();
881
882   EXPECT_EQ(server.getAsyncCallbacks(), 18);
883   EXPECT_EQ(server.getAsyncLookups(), 9);
884   EXPECT_EQ(client->getMiss(), 10);
885   EXPECT_EQ(client->getHit(), 0);
886
887   cerr << "SSLServerAsyncCacheTest test completed" << endl;
888 }
889
890
891 /**
892  * Test SSL server accept timeout with cache path
893  */
894 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
895   // Start listening on a local port
896   WriteCallbackBase writeCallback;
897   ReadCallback readCallback(&writeCallback);
898   EmptyReadCallback clientReadCallback;
899   HandshakeCallback handshakeCallback(&readCallback);
900   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
901   TestSSLAsyncCacheServer server(&acceptCallback);
902
903   // Set up SSL client
904   EventBase eventBase;
905   // only do a TCP connect
906   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
907   sock->connect(nullptr, server.getAddress());
908   clientReadCallback.tcpSocket_ = sock;
909   sock->setReadCB(&clientReadCallback);
910
911   EventBaseAborter eba(&eventBase, 3000);
912   eventBase.loop();
913
914   EXPECT_EQ(readCallback.state, STATE_WAITING);
915
916   cerr << "SSLServerTimeoutTest test completed" << endl;
917 }
918
919 /**
920  * Test SSL server accept timeout with cache path
921  */
922 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
923   // Start listening on a local port
924   WriteCallbackBase writeCallback;
925   ReadCallback readCallback(&writeCallback);
926   HandshakeCallback handshakeCallback(&readCallback);
927   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
928   TestSSLAsyncCacheServer server(&acceptCallback);
929
930   // Set up SSL client
931   EventBase eventBase;
932   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
933
934   client->connect();
935   EventBaseAborter eba(&eventBase, 3000);
936   eventBase.loop();
937
938   EXPECT_EQ(server.getAsyncCallbacks(), 1);
939   EXPECT_EQ(server.getAsyncLookups(), 1);
940   EXPECT_EQ(client->getErrors(), 1);
941   EXPECT_EQ(client->getMiss(), 1);
942   EXPECT_EQ(client->getHit(), 0);
943
944   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
945 }
946
947 /**
948  * Test SSL server accept timeout with cache path
949  */
950 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
951   // Start listening on a local port
952   WriteCallbackBase writeCallback;
953   ReadCallback readCallback(&writeCallback);
954   HandshakeCallback handshakeCallback(&readCallback,
955                                       HandshakeCallback::EXPECT_ERROR);
956   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
957   TestSSLAsyncCacheServer server(&acceptCallback, 500);
958
959   // Set up SSL client
960   EventBase eventBase;
961   auto client =
962       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
963
964   client->connect();
965   EventBaseAborter eba(&eventBase, 3000);
966   eventBase.loop();
967
968   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
969       handshakeCallback.closeSocket();});
970   // give time for the cache lookup to come back and find it closed
971   handshakeCallback.waitForHandshake();
972
973   EXPECT_EQ(server.getAsyncCallbacks(), 1);
974   EXPECT_EQ(server.getAsyncLookups(), 1);
975   EXPECT_EQ(client->getErrors(), 1);
976   EXPECT_EQ(client->getMiss(), 1);
977   EXPECT_EQ(client->getHit(), 0);
978
979   cerr << "SSLServerCacheCloseTest test completed" << endl;
980 }
981
982 /**
983  * Verify Client Ciphers obtained using SSL MSG Callback.
984  */
985 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
986   EventBase eventBase;
987   auto clientCtx = std::make_shared<SSLContext>();
988   auto serverCtx = std::make_shared<SSLContext>();
989   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
990   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
991   serverCtx->loadPrivateKey(testKey);
992   serverCtx->loadCertificate(testCert);
993   serverCtx->loadTrustedCertificates(testCA);
994   serverCtx->loadClientCAList(testCA);
995
996   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
997   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
998   clientCtx->loadPrivateKey(testKey);
999   clientCtx->loadCertificate(testCert);
1000   clientCtx->loadTrustedCertificates(testCA);
1001
1002   int fds[2];
1003   getfds(fds);
1004
1005   AsyncSSLSocket::UniquePtr clientSock(
1006       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1007   AsyncSSLSocket::UniquePtr serverSock(
1008       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1009
1010   SSLHandshakeClient client(std::move(clientSock), true, true);
1011   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1012
1013   eventBase.loop();
1014
1015   EXPECT_EQ(server.clientCiphers_,
1016             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1017   EXPECT_TRUE(client.handshakeVerify_);
1018   EXPECT_TRUE(client.handshakeSuccess_);
1019   EXPECT_TRUE(!client.handshakeError_);
1020   EXPECT_TRUE(server.handshakeVerify_);
1021   EXPECT_TRUE(server.handshakeSuccess_);
1022   EXPECT_TRUE(!server.handshakeError_);
1023 }
1024
1025 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1026   EventBase eventBase;
1027   auto ctx = std::make_shared<SSLContext>();
1028
1029   int fds[2];
1030   getfds(fds);
1031
1032   int bufLen = 42;
1033   uint8_t majorVersion = 18;
1034   uint8_t minorVersion = 25;
1035
1036   // Create callback buf
1037   auto buf = IOBuf::create(bufLen);
1038   buf->append(bufLen);
1039   folly::io::RWPrivateCursor cursor(buf.get());
1040   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1041   cursor.write<uint16_t>(0);
1042   cursor.write<uint8_t>(38);
1043   cursor.write<uint8_t>(majorVersion);
1044   cursor.write<uint8_t>(minorVersion);
1045   cursor.skip(32);
1046   cursor.write<uint32_t>(0);
1047
1048   SSL* ssl = ctx->createSSL();
1049   SCOPE_EXIT { SSL_free(ssl); };
1050   AsyncSSLSocket::UniquePtr sock(
1051       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1052   sock->enableClientHelloParsing();
1053
1054   // Test client hello parsing in one packet
1055   AsyncSSLSocket::clientHelloParsingCallback(
1056       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1057   buf.reset();
1058
1059   auto parsedClientHello = sock->getClientHelloInfo();
1060   EXPECT_TRUE(parsedClientHello != nullptr);
1061   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1062   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1063 }
1064
1065 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1066   EventBase eventBase;
1067   auto ctx = std::make_shared<SSLContext>();
1068
1069   int fds[2];
1070   getfds(fds);
1071
1072   int bufLen = 42;
1073   uint8_t majorVersion = 18;
1074   uint8_t minorVersion = 25;
1075
1076   // Create callback buf
1077   auto buf = IOBuf::create(bufLen);
1078   buf->append(bufLen);
1079   folly::io::RWPrivateCursor cursor(buf.get());
1080   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1081   cursor.write<uint16_t>(0);
1082   cursor.write<uint8_t>(38);
1083   cursor.write<uint8_t>(majorVersion);
1084   cursor.write<uint8_t>(minorVersion);
1085   cursor.skip(32);
1086   cursor.write<uint32_t>(0);
1087
1088   SSL* ssl = ctx->createSSL();
1089   SCOPE_EXIT { SSL_free(ssl); };
1090   AsyncSSLSocket::UniquePtr sock(
1091       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1092   sock->enableClientHelloParsing();
1093
1094   // Test parsing with two packets with first packet size < 3
1095   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1096   AsyncSSLSocket::clientHelloParsingCallback(
1097       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1098       ssl, sock.get());
1099   bufCopy.reset();
1100   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1101   AsyncSSLSocket::clientHelloParsingCallback(
1102       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1103       ssl, sock.get());
1104   bufCopy.reset();
1105
1106   auto parsedClientHello = sock->getClientHelloInfo();
1107   EXPECT_TRUE(parsedClientHello != nullptr);
1108   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1109   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1110 }
1111
1112 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1113   EventBase eventBase;
1114   auto ctx = std::make_shared<SSLContext>();
1115
1116   int fds[2];
1117   getfds(fds);
1118
1119   int bufLen = 42;
1120   uint8_t majorVersion = 18;
1121   uint8_t minorVersion = 25;
1122
1123   // Create callback buf
1124   auto buf = IOBuf::create(bufLen);
1125   buf->append(bufLen);
1126   folly::io::RWPrivateCursor cursor(buf.get());
1127   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1128   cursor.write<uint16_t>(0);
1129   cursor.write<uint8_t>(38);
1130   cursor.write<uint8_t>(majorVersion);
1131   cursor.write<uint8_t>(minorVersion);
1132   cursor.skip(32);
1133   cursor.write<uint32_t>(0);
1134
1135   SSL* ssl = ctx->createSSL();
1136   SCOPE_EXIT { SSL_free(ssl); };
1137   AsyncSSLSocket::UniquePtr sock(
1138       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1139   sock->enableClientHelloParsing();
1140
1141   // Test parsing with multiple small packets
1142   for (uint64_t i = 0; i < buf->length(); i += 3) {
1143     auto bufCopy = folly::IOBuf::copyBuffer(
1144         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1145     AsyncSSLSocket::clientHelloParsingCallback(
1146         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1147         ssl, sock.get());
1148     bufCopy.reset();
1149   }
1150
1151   auto parsedClientHello = sock->getClientHelloInfo();
1152   EXPECT_TRUE(parsedClientHello != nullptr);
1153   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1154   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1155 }
1156
1157 /**
1158  * Verify sucessful behavior of SSL certificate validation.
1159  */
1160 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1161   EventBase eventBase;
1162   auto clientCtx = std::make_shared<SSLContext>();
1163   auto dfServerCtx = std::make_shared<SSLContext>();
1164
1165   int fds[2];
1166   getfds(fds);
1167   getctx(clientCtx, dfServerCtx);
1168
1169   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1170   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1171
1172   AsyncSSLSocket::UniquePtr clientSock(
1173     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1174   AsyncSSLSocket::UniquePtr serverSock(
1175     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1176
1177   SSLHandshakeClient client(std::move(clientSock), true, true);
1178   clientCtx->loadTrustedCertificates(testCA);
1179
1180   SSLHandshakeServer server(std::move(serverSock), true, true);
1181
1182   eventBase.loop();
1183
1184   EXPECT_TRUE(client.handshakeVerify_);
1185   EXPECT_TRUE(client.handshakeSuccess_);
1186   EXPECT_TRUE(!client.handshakeError_);
1187   EXPECT_LE(0, client.handshakeTime.count());
1188   EXPECT_TRUE(!server.handshakeVerify_);
1189   EXPECT_TRUE(server.handshakeSuccess_);
1190   EXPECT_TRUE(!server.handshakeError_);
1191   EXPECT_LE(0, server.handshakeTime.count());
1192 }
1193
1194 /**
1195  * Verify that the client's verification callback is able to fail SSL
1196  * connection establishment.
1197  */
1198 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1199   EventBase eventBase;
1200   auto clientCtx = std::make_shared<SSLContext>();
1201   auto dfServerCtx = std::make_shared<SSLContext>();
1202
1203   int fds[2];
1204   getfds(fds);
1205   getctx(clientCtx, dfServerCtx);
1206
1207   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1208   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1209
1210   AsyncSSLSocket::UniquePtr clientSock(
1211     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1212   AsyncSSLSocket::UniquePtr serverSock(
1213     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1214
1215   SSLHandshakeClient client(std::move(clientSock), true, false);
1216   clientCtx->loadTrustedCertificates(testCA);
1217
1218   SSLHandshakeServer server(std::move(serverSock), true, true);
1219
1220   eventBase.loop();
1221
1222   EXPECT_TRUE(client.handshakeVerify_);
1223   EXPECT_TRUE(!client.handshakeSuccess_);
1224   EXPECT_TRUE(client.handshakeError_);
1225   EXPECT_LE(0, client.handshakeTime.count());
1226   EXPECT_TRUE(!server.handshakeVerify_);
1227   EXPECT_TRUE(!server.handshakeSuccess_);
1228   EXPECT_TRUE(server.handshakeError_);
1229   EXPECT_LE(0, server.handshakeTime.count());
1230 }
1231
1232 /**
1233  * Verify that the options in SSLContext can be overridden in
1234  * sslConnect/Accept.i.e specifying that no validation should be performed
1235  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1236  * the validation callback.
1237  */
1238 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1239   EventBase eventBase;
1240   auto clientCtx = std::make_shared<SSLContext>();
1241   auto dfServerCtx = std::make_shared<SSLContext>();
1242
1243   int fds[2];
1244   getfds(fds);
1245   getctx(clientCtx, dfServerCtx);
1246
1247   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1248   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1249
1250   AsyncSSLSocket::UniquePtr clientSock(
1251     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1252   AsyncSSLSocket::UniquePtr serverSock(
1253     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1254
1255   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1256   clientCtx->loadTrustedCertificates(testCA);
1257
1258   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1259
1260   eventBase.loop();
1261
1262   EXPECT_TRUE(!client.handshakeVerify_);
1263   EXPECT_TRUE(client.handshakeSuccess_);
1264   EXPECT_TRUE(!client.handshakeError_);
1265   EXPECT_LE(0, client.handshakeTime.count());
1266   EXPECT_TRUE(!server.handshakeVerify_);
1267   EXPECT_TRUE(server.handshakeSuccess_);
1268   EXPECT_TRUE(!server.handshakeError_);
1269   EXPECT_LE(0, server.handshakeTime.count());
1270 }
1271
1272 /**
1273  * Verify that the options in SSLContext can be overridden in
1274  * sslConnect/Accept. Enable verification even if context says otherwise.
1275  * Test requireClientCert with client cert
1276  */
1277 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1278   EventBase eventBase;
1279   auto clientCtx = std::make_shared<SSLContext>();
1280   auto serverCtx = std::make_shared<SSLContext>();
1281   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1282   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1283   serverCtx->loadPrivateKey(testKey);
1284   serverCtx->loadCertificate(testCert);
1285   serverCtx->loadTrustedCertificates(testCA);
1286   serverCtx->loadClientCAList(testCA);
1287
1288   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1289   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1290   clientCtx->loadPrivateKey(testKey);
1291   clientCtx->loadCertificate(testCert);
1292   clientCtx->loadTrustedCertificates(testCA);
1293
1294   int fds[2];
1295   getfds(fds);
1296
1297   AsyncSSLSocket::UniquePtr clientSock(
1298       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1299   AsyncSSLSocket::UniquePtr serverSock(
1300       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1301
1302   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1303   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1304
1305   eventBase.loop();
1306
1307   EXPECT_TRUE(client.handshakeVerify_);
1308   EXPECT_TRUE(client.handshakeSuccess_);
1309   EXPECT_FALSE(client.handshakeError_);
1310   EXPECT_LE(0, client.handshakeTime.count());
1311   EXPECT_TRUE(server.handshakeVerify_);
1312   EXPECT_TRUE(server.handshakeSuccess_);
1313   EXPECT_FALSE(server.handshakeError_);
1314   EXPECT_LE(0, server.handshakeTime.count());
1315 }
1316
1317 /**
1318  * Verify that the client's verification callback is able to override
1319  * the preverification failure and allow a successful connection.
1320  */
1321 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1322   EventBase eventBase;
1323   auto clientCtx = std::make_shared<SSLContext>();
1324   auto dfServerCtx = std::make_shared<SSLContext>();
1325
1326   int fds[2];
1327   getfds(fds);
1328   getctx(clientCtx, dfServerCtx);
1329
1330   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1331   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1332
1333   AsyncSSLSocket::UniquePtr clientSock(
1334     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1335   AsyncSSLSocket::UniquePtr serverSock(
1336     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1337
1338   SSLHandshakeClient client(std::move(clientSock), false, true);
1339   SSLHandshakeServer server(std::move(serverSock), true, true);
1340
1341   eventBase.loop();
1342
1343   EXPECT_TRUE(client.handshakeVerify_);
1344   EXPECT_TRUE(client.handshakeSuccess_);
1345   EXPECT_TRUE(!client.handshakeError_);
1346   EXPECT_LE(0, client.handshakeTime.count());
1347   EXPECT_TRUE(!server.handshakeVerify_);
1348   EXPECT_TRUE(server.handshakeSuccess_);
1349   EXPECT_TRUE(!server.handshakeError_);
1350   EXPECT_LE(0, server.handshakeTime.count());
1351 }
1352
1353 /**
1354  * Verify that specifying that no validation should be performed allows an
1355  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1356  * callback.
1357  */
1358 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1359   EventBase eventBase;
1360   auto clientCtx = std::make_shared<SSLContext>();
1361   auto dfServerCtx = std::make_shared<SSLContext>();
1362
1363   int fds[2];
1364   getfds(fds);
1365   getctx(clientCtx, dfServerCtx);
1366
1367   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1368   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1369
1370   AsyncSSLSocket::UniquePtr clientSock(
1371     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1372   AsyncSSLSocket::UniquePtr serverSock(
1373     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1374
1375   SSLHandshakeClient client(std::move(clientSock), false, false);
1376   SSLHandshakeServer server(std::move(serverSock), false, false);
1377
1378   eventBase.loop();
1379
1380   EXPECT_TRUE(!client.handshakeVerify_);
1381   EXPECT_TRUE(client.handshakeSuccess_);
1382   EXPECT_TRUE(!client.handshakeError_);
1383   EXPECT_LE(0, client.handshakeTime.count());
1384   EXPECT_TRUE(!server.handshakeVerify_);
1385   EXPECT_TRUE(server.handshakeSuccess_);
1386   EXPECT_TRUE(!server.handshakeError_);
1387   EXPECT_LE(0, server.handshakeTime.count());
1388 }
1389
1390 /**
1391  * Test requireClientCert with client cert
1392  */
1393 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1394   EventBase eventBase;
1395   auto clientCtx = std::make_shared<SSLContext>();
1396   auto serverCtx = std::make_shared<SSLContext>();
1397   serverCtx->setVerificationOption(
1398       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1399   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1400   serverCtx->loadPrivateKey(testKey);
1401   serverCtx->loadCertificate(testCert);
1402   serverCtx->loadTrustedCertificates(testCA);
1403   serverCtx->loadClientCAList(testCA);
1404
1405   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1406   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1407   clientCtx->loadPrivateKey(testKey);
1408   clientCtx->loadCertificate(testCert);
1409   clientCtx->loadTrustedCertificates(testCA);
1410
1411   int fds[2];
1412   getfds(fds);
1413
1414   AsyncSSLSocket::UniquePtr clientSock(
1415       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1416   AsyncSSLSocket::UniquePtr serverSock(
1417       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1418
1419   SSLHandshakeClient client(std::move(clientSock), true, true);
1420   SSLHandshakeServer server(std::move(serverSock), true, true);
1421
1422   eventBase.loop();
1423
1424   EXPECT_TRUE(client.handshakeVerify_);
1425   EXPECT_TRUE(client.handshakeSuccess_);
1426   EXPECT_FALSE(client.handshakeError_);
1427   EXPECT_LE(0, client.handshakeTime.count());
1428   EXPECT_TRUE(server.handshakeVerify_);
1429   EXPECT_TRUE(server.handshakeSuccess_);
1430   EXPECT_FALSE(server.handshakeError_);
1431   EXPECT_LE(0, server.handshakeTime.count());
1432 }
1433
1434
1435 /**
1436  * Test requireClientCert with no client cert
1437  */
1438 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1439   EventBase eventBase;
1440   auto clientCtx = std::make_shared<SSLContext>();
1441   auto serverCtx = std::make_shared<SSLContext>();
1442   serverCtx->setVerificationOption(
1443       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1444   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1445   serverCtx->loadPrivateKey(testKey);
1446   serverCtx->loadCertificate(testCert);
1447   serverCtx->loadTrustedCertificates(testCA);
1448   serverCtx->loadClientCAList(testCA);
1449   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1450   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1451
1452   int fds[2];
1453   getfds(fds);
1454
1455   AsyncSSLSocket::UniquePtr clientSock(
1456       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1457   AsyncSSLSocket::UniquePtr serverSock(
1458       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1459
1460   SSLHandshakeClient client(std::move(clientSock), false, false);
1461   SSLHandshakeServer server(std::move(serverSock), false, false);
1462
1463   eventBase.loop();
1464
1465   EXPECT_FALSE(server.handshakeVerify_);
1466   EXPECT_FALSE(server.handshakeSuccess_);
1467   EXPECT_TRUE(server.handshakeError_);
1468   EXPECT_LE(0, client.handshakeTime.count());
1469   EXPECT_LE(0, server.handshakeTime.count());
1470 }
1471
1472 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1473   auto cert = getFileAsBuf(testCert);
1474   auto key = getFileAsBuf(testKey);
1475
1476   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1477   BIO_write(certBio.get(), cert.data(), cert.size());
1478   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1479   BIO_write(keyBio.get(), key.data(), key.size());
1480
1481   // Create SSL structs from buffers to get properties
1482   ssl::X509UniquePtr certStruct(
1483       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1484   ssl::EvpPkeyUniquePtr keyStruct(
1485       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1486   certBio = nullptr;
1487   keyBio = nullptr;
1488
1489   auto origCommonName = getCommonName(certStruct.get());
1490   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1491   certStruct = nullptr;
1492   keyStruct = nullptr;
1493
1494   auto ctx = std::make_shared<SSLContext>();
1495   ctx->loadPrivateKeyFromBufferPEM(key);
1496   ctx->loadCertificateFromBufferPEM(cert);
1497   ctx->loadTrustedCertificates(testCA);
1498
1499   ssl::SSLUniquePtr ssl(ctx->createSSL());
1500
1501   auto newCert = SSL_get_certificate(ssl.get());
1502   auto newKey = SSL_get_privatekey(ssl.get());
1503
1504   // Get properties from SSL struct
1505   auto newCommonName = getCommonName(newCert);
1506   auto newKeySize = EVP_PKEY_bits(newKey);
1507
1508   // Check that the key and cert have the expected properties
1509   EXPECT_EQ(origCommonName, newCommonName);
1510   EXPECT_EQ(origKeySize, newKeySize);
1511 }
1512
1513 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1514   EventBase eb;
1515
1516   // Set up SSL context.
1517   auto sslContext = std::make_shared<SSLContext>();
1518   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1519
1520   // create SSL socket
1521   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1522
1523   EXPECT_EQ(1500, socket->getMinWriteSize());
1524
1525   socket->setMinWriteSize(0);
1526   EXPECT_EQ(0, socket->getMinWriteSize());
1527   socket->setMinWriteSize(50000);
1528   EXPECT_EQ(50000, socket->getMinWriteSize());
1529 }
1530
1531 class ReadCallbackTerminator : public ReadCallback {
1532  public:
1533   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1534       : ReadCallback(wcb)
1535       , base_(base) {}
1536
1537   // Do not write data back, terminate the loop.
1538   void readDataAvailable(size_t len) noexcept override {
1539     std::cerr << "readDataAvailable, len " << len << std::endl;
1540
1541     currentBuffer.length = len;
1542
1543     buffers.push_back(currentBuffer);
1544     currentBuffer.reset();
1545     state = STATE_SUCCEEDED;
1546
1547     socket_->setReadCB(nullptr);
1548     base_->terminateLoopSoon();
1549   }
1550  private:
1551   EventBase* base_;
1552 };
1553
1554
1555 /**
1556  * Test a full unencrypted codepath
1557  */
1558 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1559   EventBase base;
1560
1561   auto clientCtx = std::make_shared<folly::SSLContext>();
1562   auto serverCtx = std::make_shared<folly::SSLContext>();
1563   int fds[2];
1564   getfds(fds);
1565   getctx(clientCtx, serverCtx);
1566   auto client = AsyncSSLSocket::newSocket(
1567                   clientCtx, &base, fds[0], false, true);
1568   auto server = AsyncSSLSocket::newSocket(
1569                   serverCtx, &base, fds[1], true, true);
1570
1571   ReadCallbackTerminator readCallback(&base, nullptr);
1572   server->setReadCB(&readCallback);
1573   readCallback.setSocket(server);
1574
1575   uint8_t buf[128];
1576   memset(buf, 'a', sizeof(buf));
1577   client->write(nullptr, buf, sizeof(buf));
1578
1579   // Check that bytes are unencrypted
1580   char c;
1581   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1582   EXPECT_EQ('a', c);
1583
1584   EventBaseAborter eba(&base, 3000);
1585   base.loop();
1586
1587   EXPECT_EQ(1, readCallback.buffers.size());
1588   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1589
1590   server->setReadCB(&readCallback);
1591
1592   // Unencrypted
1593   server->sslAccept(nullptr);
1594   client->sslConn(nullptr);
1595
1596   // Do NOT wait for handshake, writing should be queued and happen after
1597
1598   client->write(nullptr, buf, sizeof(buf));
1599
1600   // Check that bytes are *not* unencrypted
1601   char c2;
1602   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1603   EXPECT_NE('a', c2);
1604
1605
1606   base.loop();
1607
1608   EXPECT_EQ(2, readCallback.buffers.size());
1609   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1610 }
1611
1612 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1613   // Start listening on a local port
1614   WriteCallbackBase writeCallback;
1615   WriteErrorCallback readCallback(&writeCallback);
1616   HandshakeCallback handshakeCallback(&readCallback,
1617                                       HandshakeCallback::EXPECT_ERROR);
1618   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1619   TestSSLServer server(&acceptCallback);
1620
1621   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1622   socket->open();
1623   uint8_t buf[3] = {0x16, 0x03, 0x01};
1624   socket->write(buf, sizeof(buf));
1625   socket->closeWithReset();
1626
1627   handshakeCallback.waitForHandshake();
1628   EXPECT_NE(
1629       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1630   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1631 }
1632
1633 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1634   // Start listening on a local port
1635   WriteCallbackBase writeCallback;
1636   WriteErrorCallback readCallback(&writeCallback);
1637   HandshakeCallback handshakeCallback(&readCallback,
1638                                       HandshakeCallback::EXPECT_ERROR);
1639   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1640   TestSSLServer server(&acceptCallback);
1641
1642   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1643   socket->open();
1644   uint8_t buf[3] = {0x16, 0x03, 0x01};
1645   socket->write(buf, sizeof(buf));
1646   socket->close();
1647
1648   handshakeCallback.waitForHandshake();
1649   EXPECT_NE(
1650       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1651   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1652 }
1653
1654 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1655   // Start listening on a local port
1656   WriteCallbackBase writeCallback;
1657   WriteErrorCallback readCallback(&writeCallback);
1658   HandshakeCallback handshakeCallback(&readCallback,
1659                                       HandshakeCallback::EXPECT_ERROR);
1660   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1661   TestSSLServer server(&acceptCallback);
1662
1663   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1664   socket->open();
1665   uint8_t buf[256] = {0x16, 0x03};
1666   memset(buf + 2, 'a', sizeof(buf) - 2);
1667   socket->write(buf, sizeof(buf));
1668   socket->close();
1669
1670   handshakeCallback.waitForHandshake();
1671   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1672             std::string::npos);
1673   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1674             std::string::npos);
1675 }
1676
1677 } // namespace
1678
1679 ///////////////////////////////////////////////////////////////////////////
1680 // init_unit_test_suite
1681 ///////////////////////////////////////////////////////////////////////////
1682 namespace {
1683 struct Initializer {
1684   Initializer() {
1685     signal(SIGPIPE, SIG_IGN);
1686   }
1687 };
1688 Initializer initializer;
1689 } // anonymous