Add additional ALPN mismatch tests.
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <fstream>
28 #include <gtest/gtest.h>
29 #include <iostream>
30 #include <list>
31 #include <set>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <openssl/bio.h>
35 #include <poll.h>
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/tcp.h>
39 #include <folly/io/Cursor.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 namespace folly {
49 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
50 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
51 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
52
53 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
54 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
55 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
56
57 constexpr size_t SSLClient::kMaxReadBufferSz;
58 constexpr size_t SSLClient::kMaxReadsPerEvent;
59
60 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
61     : ctx_(new folly::SSLContext),
62       acb_(acb),
63       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
64   // Set up the SSL context
65   ctx_->loadCertificate(testCert);
66   ctx_->loadPrivateKey(testKey);
67   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
68
69   acb_->ctx_ = ctx_;
70   acb_->base_ = &evb_;
71
72   //set up the listening socket
73   socket_->bind(0);
74   socket_->getAddress(&address_);
75   socket_->listen(100);
76   socket_->addAcceptCallback(acb_, &evb_);
77   socket_->startAccepting();
78
79   int ret = pthread_create(&thread_, nullptr, Main, this);
80   assert(ret == 0);
81   (void)ret;
82
83   std::cerr << "Accepting connections on " << address_ << std::endl;
84 }
85
86 void getfds(int fds[2]) {
87   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
88     FAIL() << "failed to create socketpair: " << strerror(errno);
89   }
90   for (int idx = 0; idx < 2; ++idx) {
91     int flags = fcntl(fds[idx], F_GETFL, 0);
92     if (flags == -1) {
93       FAIL() << "failed to get flags for socket " << idx << ": "
94              << strerror(errno);
95     }
96     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
97       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
98              << strerror(errno);
99     }
100   }
101 }
102
103 void getctx(
104   std::shared_ptr<folly::SSLContext> clientCtx,
105   std::shared_ptr<folly::SSLContext> serverCtx) {
106   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107
108   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
109   serverCtx->loadCertificate(
110       testCert);
111   serverCtx->loadPrivateKey(
112       testKey);
113 }
114
115 void sslsocketpair(
116   EventBase* eventBase,
117   AsyncSSLSocket::UniquePtr* clientSock,
118   AsyncSSLSocket::UniquePtr* serverSock) {
119   auto clientCtx = std::make_shared<folly::SSLContext>();
120   auto serverCtx = std::make_shared<folly::SSLContext>();
121   int fds[2];
122   getfds(fds);
123   getctx(clientCtx, serverCtx);
124   clientSock->reset(new AsyncSSLSocket(
125                       clientCtx, eventBase, fds[0], false));
126   serverSock->reset(new AsyncSSLSocket(
127                       serverCtx, eventBase, fds[1], true));
128
129   // (*clientSock)->setSendTimeout(100);
130   // (*serverSock)->setSendTimeout(100);
131 }
132
133 // client protocol filters
134 bool clientProtoFilterPickPony(unsigned char** client,
135   unsigned int* client_len, const unsigned char*, unsigned int ) {
136   //the protocol string in length prefixed byte string. the
137   //length byte is not included in the length
138   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
139   *client = p;
140   *client_len = 7;
141   return true;
142 }
143
144 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
145   const unsigned char*, unsigned int) {
146   return false;
147 }
148
149 std::string getFileAsBuf(const char* fileName) {
150   std::string buffer;
151   folly::readFile(fileName, buffer);
152   return buffer;
153 }
154
155 std::string getCommonName(X509* cert) {
156   X509_NAME* subject = X509_get_subject_name(cert);
157   std::string cn;
158   cn.resize(ub_common_name);
159   X509_NAME_get_text_by_NID(
160       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
161   return cn;
162 }
163
164 /**
165  * Test connecting to, writing to, reading from, and closing the
166  * connection to the SSL server.
167  */
168 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
169   // Start listening on a local port
170   WriteCallbackBase writeCallback;
171   ReadCallback readCallback(&writeCallback);
172   HandshakeCallback handshakeCallback(&readCallback);
173   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
174   TestSSLServer server(&acceptCallback);
175
176   // Set up SSL context.
177   std::shared_ptr<SSLContext> sslContext(new SSLContext());
178   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
179   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
180   //sslContext->authenticate(true, false);
181
182   // connect
183   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
184                                                  sslContext);
185   socket->open();
186
187   // write()
188   uint8_t buf[128];
189   memset(buf, 'a', sizeof(buf));
190   socket->write(buf, sizeof(buf));
191
192   // read()
193   uint8_t readbuf[128];
194   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
195   EXPECT_EQ(bytesRead, 128);
196   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
197
198   // close()
199   socket->close();
200
201   cerr << "ConnectWriteReadClose test completed" << endl;
202 }
203
204 /**
205  * Test reading after server close.
206  */
207 TEST(AsyncSSLSocketTest, ReadAfterClose) {
208   // Start listening on a local port
209   WriteCallbackBase writeCallback;
210   ReadEOFCallback readCallback(&writeCallback);
211   HandshakeCallback handshakeCallback(&readCallback);
212   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
213   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
214
215   // Set up SSL context.
216   auto sslContext = std::make_shared<SSLContext>();
217   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
218
219   auto socket =
220       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
221   socket->open();
222
223   // This should trigger an EOF on the client.
224   auto evb = handshakeCallback.getSocket()->getEventBase();
225   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
226   std::array<uint8_t, 128> readbuf;
227   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
228   EXPECT_EQ(0, bytesRead);
229 }
230
231 /**
232  * Test bad renegotiation
233  */
234 TEST(AsyncSSLSocketTest, Renegotiate) {
235   EventBase eventBase;
236   auto clientCtx = std::make_shared<SSLContext>();
237   auto dfServerCtx = std::make_shared<SSLContext>();
238   std::array<int, 2> fds;
239   getfds(fds.data());
240   getctx(clientCtx, dfServerCtx);
241
242   AsyncSSLSocket::UniquePtr clientSock(
243       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
244   AsyncSSLSocket::UniquePtr serverSock(
245       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
246   SSLHandshakeClient client(std::move(clientSock), true, true);
247   RenegotiatingServer server(std::move(serverSock));
248
249   while (!client.handshakeSuccess_ && !client.handshakeError_) {
250     eventBase.loopOnce();
251   }
252
253   ASSERT_TRUE(client.handshakeSuccess_);
254
255   auto sslSock = std::move(client).moveSocket();
256   sslSock->detachEventBase();
257   // This is nasty, however we don't want to add support for
258   // renegotiation in AsyncSSLSocket.
259   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
260
261   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
262
263   std::thread t([&]() { eventBase.loopForever(); });
264
265   // Trigger the renegotiation.
266   std::array<uint8_t, 128> buf;
267   memset(buf.data(), 'a', buf.size());
268   try {
269     socket->write(buf.data(), buf.size());
270   } catch (AsyncSocketException& e) {
271     LOG(INFO) << "client got error " << e.what();
272   }
273   eventBase.terminateLoopSoon();
274   t.join();
275
276   eventBase.loop();
277   ASSERT_TRUE(server.renegotiationError_);
278 }
279
280 /**
281  * Negative test for handshakeError().
282  */
283 TEST(AsyncSSLSocketTest, HandshakeError) {
284   // Start listening on a local port
285   WriteCallbackBase writeCallback;
286   WriteErrorCallback readCallback(&writeCallback);
287   HandshakeCallback handshakeCallback(&readCallback);
288   HandshakeErrorCallback acceptCallback(&handshakeCallback);
289   TestSSLServer server(&acceptCallback);
290
291   // Set up SSL context.
292   std::shared_ptr<SSLContext> sslContext(new SSLContext());
293   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
294
295   // connect
296   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
297                                                  sslContext);
298   // read()
299   bool ex = false;
300   try {
301     socket->open();
302
303     uint8_t readbuf[128];
304     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
305     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
306   } catch (AsyncSocketException &e) {
307     ex = true;
308   }
309   EXPECT_TRUE(ex);
310
311   // close()
312   socket->close();
313   cerr << "HandshakeError test completed" << endl;
314 }
315
316 /**
317  * Negative test for readError().
318  */
319 TEST(AsyncSSLSocketTest, ReadError) {
320   // Start listening on a local port
321   WriteCallbackBase writeCallback;
322   ReadErrorCallback readCallback(&writeCallback);
323   HandshakeCallback handshakeCallback(&readCallback);
324   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
325   TestSSLServer server(&acceptCallback);
326
327   // Set up SSL context.
328   std::shared_ptr<SSLContext> sslContext(new SSLContext());
329   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
330
331   // connect
332   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
333                                                  sslContext);
334   socket->open();
335
336   // write something to trigger ssl handshake
337   uint8_t buf[128];
338   memset(buf, 'a', sizeof(buf));
339   socket->write(buf, sizeof(buf));
340
341   socket->close();
342   cerr << "ReadError test completed" << endl;
343 }
344
345 /**
346  * Negative test for writeError().
347  */
348 TEST(AsyncSSLSocketTest, WriteError) {
349   // Start listening on a local port
350   WriteCallbackBase writeCallback;
351   WriteErrorCallback readCallback(&writeCallback);
352   HandshakeCallback handshakeCallback(&readCallback);
353   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
354   TestSSLServer server(&acceptCallback);
355
356   // Set up SSL context.
357   std::shared_ptr<SSLContext> sslContext(new SSLContext());
358   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
359
360   // connect
361   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
362                                                  sslContext);
363   socket->open();
364
365   // write something to trigger ssl handshake
366   uint8_t buf[128];
367   memset(buf, 'a', sizeof(buf));
368   socket->write(buf, sizeof(buf));
369
370   socket->close();
371   cerr << "WriteError test completed" << endl;
372 }
373
374 /**
375  * Test a socket with TCP_NODELAY unset.
376  */
377 TEST(AsyncSSLSocketTest, SocketWithDelay) {
378   // Start listening on a local port
379   WriteCallbackBase writeCallback;
380   ReadCallback readCallback(&writeCallback);
381   HandshakeCallback handshakeCallback(&readCallback);
382   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
383   TestSSLServer server(&acceptCallback);
384
385   // Set up SSL context.
386   std::shared_ptr<SSLContext> sslContext(new SSLContext());
387   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
388
389   // connect
390   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
391                                                  sslContext);
392   socket->open();
393
394   // write()
395   uint8_t buf[128];
396   memset(buf, 'a', sizeof(buf));
397   socket->write(buf, sizeof(buf));
398
399   // read()
400   uint8_t readbuf[128];
401   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
402   EXPECT_EQ(bytesRead, 128);
403   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
404
405   // close()
406   socket->close();
407
408   cerr << "SocketWithDelay test completed" << endl;
409 }
410
411 using NextProtocolTypePair =
412     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
413
414 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
415   // For matching protos
416  public:
417   void SetUp() override { getctx(clientCtx, serverCtx); }
418
419   void connect(bool unset = false) {
420     getfds(fds);
421
422     if (unset) {
423       // unsetting NPN for any of [client, server] is enough to make NPN not
424       // work
425       clientCtx->unsetNextProtocols();
426     }
427
428     AsyncSSLSocket::UniquePtr clientSock(
429       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
430     AsyncSSLSocket::UniquePtr serverSock(
431       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
432     client = folly::make_unique<NpnClient>(std::move(clientSock));
433     server = folly::make_unique<NpnServer>(std::move(serverSock));
434
435     eventBase.loop();
436   }
437
438   void expectProtocol(const std::string& proto) {
439     EXPECT_NE(client->nextProtoLength, 0);
440     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
441     EXPECT_EQ(
442         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
443         0);
444     string selected((const char*)client->nextProto, client->nextProtoLength);
445     EXPECT_EQ(proto, selected);
446   }
447
448   void expectNoProtocol() {
449     EXPECT_EQ(client->nextProtoLength, 0);
450     EXPECT_EQ(server->nextProtoLength, 0);
451     EXPECT_EQ(client->nextProto, nullptr);
452     EXPECT_EQ(server->nextProto, nullptr);
453   }
454
455   void expectProtocolType() {
456     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
457         GetParam().second == SSLContext::NextProtocolType::ANY) {
458       EXPECT_EQ(client->protocolType, server->protocolType);
459     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
460                GetParam().second == SSLContext::NextProtocolType::ANY) {
461       // Well not much we can say
462     } else {
463       expectProtocolType(GetParam());
464     }
465   }
466
467   void expectProtocolType(NextProtocolTypePair expected) {
468     EXPECT_EQ(client->protocolType, expected.first);
469     EXPECT_EQ(server->protocolType, expected.second);
470   }
471
472   EventBase eventBase;
473   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
474   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
475   int fds[2];
476   std::unique_ptr<NpnClient> client;
477   std::unique_ptr<NpnServer> server;
478 };
479
480 class NextProtocolNPNOnlyTest : public NextProtocolTest {
481   // For mismatching protos
482 };
483
484 class NextProtocolMismatchTest : public NextProtocolTest {
485   // For mismatching protos
486 };
487
488 TEST_P(NextProtocolTest, NpnTestOverlap) {
489   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
490   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
491                                         GetParam().second);
492
493   connect();
494
495   expectProtocol("baz");
496   expectProtocolType();
497 }
498
499 TEST_P(NextProtocolTest, NpnTestUnset) {
500   // Identical to above test, except that we want unset NPN before
501   // looping.
502   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
503   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
504                                         GetParam().second);
505
506   connect(true /* unset */);
507
508   // if alpn negotiation fails, type will appear as npn
509   expectNoProtocol();
510   EXPECT_EQ(client->protocolType, server->protocolType);
511 }
512
513 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
514   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
515   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
516                                         GetParam().second);
517
518   connect();
519
520   expectNoProtocol();
521   expectProtocolType(
522       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
523 }
524
525 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
526 // will fail on 1.0.2 before that.
527 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
528   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
529   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
530                                         GetParam().second);
531
532   connect();
533
534   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
535       GetParam().second == SSLContext::NextProtocolType::ALPN) {
536     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
537     // mismatch should result in a fatal alert, but this is OpenSSL's current
538     // behavior and we want to know if it changes.
539     expectNoProtocol();
540   } else {
541     expectProtocol("blub");
542     expectProtocolType(
543         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
544   }
545 }
546
547 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
548   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
549   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
550   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
551                                         GetParam().second);
552
553   connect();
554
555   expectProtocol("ponies");
556   expectProtocolType();
557 }
558
559 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
560   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
561   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
562   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
563                                         GetParam().second);
564
565   connect();
566
567   expectProtocol("blub");
568   expectProtocolType();
569 }
570
571 TEST_P(NextProtocolTest, RandomizedNpnTest) {
572   // Probability that this test will fail is 2^-64, which could be considered
573   // as negligible.
574   const int kTries = 64;
575
576   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
577                                         GetParam().first);
578   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
579                                                   GetParam().second);
580
581   std::set<string> selectedProtocols;
582   for (int i = 0; i < kTries; ++i) {
583     connect();
584
585     EXPECT_NE(client->nextProtoLength, 0);
586     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
587     EXPECT_EQ(
588         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
589         0);
590     string selected((const char*)client->nextProto, client->nextProtoLength);
591     selectedProtocols.insert(selected);
592     expectProtocolType();
593   }
594   EXPECT_EQ(selectedProtocols.size(), 2);
595 }
596
597 INSTANTIATE_TEST_CASE_P(
598     AsyncSSLSocketTest,
599     NextProtocolTest,
600     ::testing::Values(
601         NextProtocolTypePair(
602             SSLContext::NextProtocolType::NPN,
603             SSLContext::NextProtocolType::NPN),
604 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
605         NextProtocolTypePair(
606             SSLContext::NextProtocolType::ALPN,
607             SSLContext::NextProtocolType::ALPN),
608         NextProtocolTypePair(
609             SSLContext::NextProtocolType::ALPN,
610             SSLContext::NextProtocolType::ANY),
611         NextProtocolTypePair(
612             SSLContext::NextProtocolType::ANY,
613             SSLContext::NextProtocolType::ALPN),
614 #endif
615         NextProtocolTypePair(
616             SSLContext::NextProtocolType::NPN,
617             SSLContext::NextProtocolType::ANY),
618         NextProtocolTypePair(
619             SSLContext::NextProtocolType::ANY,
620             SSLContext::NextProtocolType::ANY)));
621
622 INSTANTIATE_TEST_CASE_P(
623     AsyncSSLSocketTest,
624     NextProtocolNPNOnlyTest,
625     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
626                                            SSLContext::NextProtocolType::NPN)));
627
628 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
629 INSTANTIATE_TEST_CASE_P(
630     AsyncSSLSocketTest,
631     NextProtocolMismatchTest,
632     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
633                                            SSLContext::NextProtocolType::ALPN),
634                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
635                                            SSLContext::NextProtocolType::NPN)));
636 #endif
637
638 #ifndef OPENSSL_NO_TLSEXT
639 /**
640  * 1. Client sends TLSEXT_HOSTNAME in client hello.
641  * 2. Server found a match SSL_CTX and use this SSL_CTX to
642  *    continue the SSL handshake.
643  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
644  */
645 TEST(AsyncSSLSocketTest, SNITestMatch) {
646   EventBase eventBase;
647   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
648   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
649   // Use the same SSLContext to continue the handshake after
650   // tlsext_hostname match.
651   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
652   const std::string serverName("xyz.newdev.facebook.com");
653   int fds[2];
654   getfds(fds);
655   getctx(clientCtx, dfServerCtx);
656
657   AsyncSSLSocket::UniquePtr clientSock(
658     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
659   AsyncSSLSocket::UniquePtr serverSock(
660     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
661   SNIClient client(std::move(clientSock));
662   SNIServer server(std::move(serverSock),
663                    dfServerCtx,
664                    hskServerCtx,
665                    serverName);
666
667   eventBase.loop();
668
669   EXPECT_TRUE(client.serverNameMatch);
670   EXPECT_TRUE(server.serverNameMatch);
671 }
672
673 /**
674  * 1. Client sends TLSEXT_HOSTNAME in client hello.
675  * 2. Server cannot find a matching SSL_CTX and continue to use
676  *    the current SSL_CTX to do the handshake.
677  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
678  */
679 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
680   EventBase eventBase;
681   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
682   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
683   // Use the same SSLContext to continue the handshake after
684   // tlsext_hostname match.
685   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
686   const std::string clientRequestingServerName("foo.com");
687   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
688
689   int fds[2];
690   getfds(fds);
691   getctx(clientCtx, dfServerCtx);
692
693   AsyncSSLSocket::UniquePtr clientSock(
694     new AsyncSSLSocket(clientCtx,
695                         &eventBase,
696                         fds[0],
697                         clientRequestingServerName));
698   AsyncSSLSocket::UniquePtr serverSock(
699     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
700   SNIClient client(std::move(clientSock));
701   SNIServer server(std::move(serverSock),
702                    dfServerCtx,
703                    hskServerCtx,
704                    serverExpectedServerName);
705
706   eventBase.loop();
707
708   EXPECT_TRUE(!client.serverNameMatch);
709   EXPECT_TRUE(!server.serverNameMatch);
710 }
711 /**
712  * 1. Client sends TLSEXT_HOSTNAME in client hello.
713  * 2. We then change the serverName.
714  * 3. We expect that we get 'false' as the result for serNameMatch.
715  */
716
717 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
718    EventBase eventBase;
719   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
720   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
721   // Use the same SSLContext to continue the handshake after
722   // tlsext_hostname match.
723   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
724   const std::string serverName("xyz.newdev.facebook.com");
725   int fds[2];
726   getfds(fds);
727   getctx(clientCtx, dfServerCtx);
728
729   AsyncSSLSocket::UniquePtr clientSock(
730     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
731   //Change the server name
732   std::string newName("new.com");
733   clientSock->setServerName(newName);
734   AsyncSSLSocket::UniquePtr serverSock(
735     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
736   SNIClient client(std::move(clientSock));
737   SNIServer server(std::move(serverSock),
738                    dfServerCtx,
739                    hskServerCtx,
740                    serverName);
741
742   eventBase.loop();
743
744   EXPECT_TRUE(!client.serverNameMatch);
745 }
746
747 /**
748  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
749  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
750  */
751 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
752   EventBase eventBase;
753   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
754   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
755   // Use the same SSLContext to continue the handshake after
756   // tlsext_hostname match.
757   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
758   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
759
760   int fds[2];
761   getfds(fds);
762   getctx(clientCtx, dfServerCtx);
763
764   AsyncSSLSocket::UniquePtr clientSock(
765     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
766   AsyncSSLSocket::UniquePtr serverSock(
767     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
768   SNIClient client(std::move(clientSock));
769   SNIServer server(std::move(serverSock),
770                    dfServerCtx,
771                    hskServerCtx,
772                    serverExpectedServerName);
773
774   eventBase.loop();
775
776   EXPECT_TRUE(!client.serverNameMatch);
777   EXPECT_TRUE(!server.serverNameMatch);
778 }
779
780 #endif
781 /**
782  * Test SSL client socket
783  */
784 TEST(AsyncSSLSocketTest, SSLClientTest) {
785   // Start listening on a local port
786   WriteCallbackBase writeCallback;
787   ReadCallback readCallback(&writeCallback);
788   HandshakeCallback handshakeCallback(&readCallback);
789   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
790   TestSSLServer server(&acceptCallback);
791
792   // Set up SSL client
793   EventBase eventBase;
794   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
795
796   client->connect();
797   EventBaseAborter eba(&eventBase, 3000);
798   eventBase.loop();
799
800   EXPECT_EQ(client->getMiss(), 1);
801   EXPECT_EQ(client->getHit(), 0);
802
803   cerr << "SSLClientTest test completed" << endl;
804 }
805
806
807 /**
808  * Test SSL client socket session re-use
809  */
810 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
811   // Start listening on a local port
812   WriteCallbackBase writeCallback;
813   ReadCallback readCallback(&writeCallback);
814   HandshakeCallback handshakeCallback(&readCallback);
815   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
816   TestSSLServer server(&acceptCallback);
817
818   // Set up SSL client
819   EventBase eventBase;
820   auto client =
821       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
822
823   client->connect();
824   EventBaseAborter eba(&eventBase, 3000);
825   eventBase.loop();
826
827   EXPECT_EQ(client->getMiss(), 1);
828   EXPECT_EQ(client->getHit(), 9);
829
830   cerr << "SSLClientTestReuse test completed" << endl;
831 }
832
833 /**
834  * Test SSL client socket timeout
835  */
836 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
837   // Start listening on a local port
838   EmptyReadCallback readCallback;
839   HandshakeCallback handshakeCallback(&readCallback,
840                                       HandshakeCallback::EXPECT_ERROR);
841   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
842   TestSSLServer server(&acceptCallback);
843
844   // Set up SSL client
845   EventBase eventBase;
846   auto client =
847       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
848   client->connect(true /* write before connect completes */);
849   EventBaseAborter eba(&eventBase, 3000);
850   eventBase.loop();
851
852   usleep(100000);
853   // This is checking that the connectError callback precedes any queued
854   // writeError callbacks.  This matches AsyncSocket's behavior
855   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
856   EXPECT_EQ(client->getErrors(), 1);
857   EXPECT_EQ(client->getMiss(), 0);
858   EXPECT_EQ(client->getHit(), 0);
859
860   cerr << "SSLClientTimeoutTest test completed" << endl;
861 }
862
863
864 /**
865  * Test SSL server async cache
866  */
867 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
868   // Start listening on a local port
869   WriteCallbackBase writeCallback;
870   ReadCallback readCallback(&writeCallback);
871   HandshakeCallback handshakeCallback(&readCallback);
872   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
873   TestSSLAsyncCacheServer server(&acceptCallback);
874
875   // Set up SSL client
876   EventBase eventBase;
877   auto client =
878       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
879
880   client->connect();
881   EventBaseAborter eba(&eventBase, 3000);
882   eventBase.loop();
883
884   EXPECT_EQ(server.getAsyncCallbacks(), 18);
885   EXPECT_EQ(server.getAsyncLookups(), 9);
886   EXPECT_EQ(client->getMiss(), 10);
887   EXPECT_EQ(client->getHit(), 0);
888
889   cerr << "SSLServerAsyncCacheTest test completed" << endl;
890 }
891
892
893 /**
894  * Test SSL server accept timeout with cache path
895  */
896 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
897   // Start listening on a local port
898   WriteCallbackBase writeCallback;
899   ReadCallback readCallback(&writeCallback);
900   EmptyReadCallback clientReadCallback;
901   HandshakeCallback handshakeCallback(&readCallback);
902   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
903   TestSSLAsyncCacheServer server(&acceptCallback);
904
905   // Set up SSL client
906   EventBase eventBase;
907   // only do a TCP connect
908   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
909   sock->connect(nullptr, server.getAddress());
910   clientReadCallback.tcpSocket_ = sock;
911   sock->setReadCB(&clientReadCallback);
912
913   EventBaseAborter eba(&eventBase, 3000);
914   eventBase.loop();
915
916   EXPECT_EQ(readCallback.state, STATE_WAITING);
917
918   cerr << "SSLServerTimeoutTest test completed" << endl;
919 }
920
921 /**
922  * Test SSL server accept timeout with cache path
923  */
924 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
925   // Start listening on a local port
926   WriteCallbackBase writeCallback;
927   ReadCallback readCallback(&writeCallback);
928   HandshakeCallback handshakeCallback(&readCallback);
929   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
930   TestSSLAsyncCacheServer server(&acceptCallback);
931
932   // Set up SSL client
933   EventBase eventBase;
934   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
935
936   client->connect();
937   EventBaseAborter eba(&eventBase, 3000);
938   eventBase.loop();
939
940   EXPECT_EQ(server.getAsyncCallbacks(), 1);
941   EXPECT_EQ(server.getAsyncLookups(), 1);
942   EXPECT_EQ(client->getErrors(), 1);
943   EXPECT_EQ(client->getMiss(), 1);
944   EXPECT_EQ(client->getHit(), 0);
945
946   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
947 }
948
949 /**
950  * Test SSL server accept timeout with cache path
951  */
952 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
953   // Start listening on a local port
954   WriteCallbackBase writeCallback;
955   ReadCallback readCallback(&writeCallback);
956   HandshakeCallback handshakeCallback(&readCallback,
957                                       HandshakeCallback::EXPECT_ERROR);
958   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
959   TestSSLAsyncCacheServer server(&acceptCallback, 500);
960
961   // Set up SSL client
962   EventBase eventBase;
963   auto client =
964       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
965
966   client->connect();
967   EventBaseAborter eba(&eventBase, 3000);
968   eventBase.loop();
969
970   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
971       handshakeCallback.closeSocket();});
972   // give time for the cache lookup to come back and find it closed
973   handshakeCallback.waitForHandshake();
974
975   EXPECT_EQ(server.getAsyncCallbacks(), 1);
976   EXPECT_EQ(server.getAsyncLookups(), 1);
977   EXPECT_EQ(client->getErrors(), 1);
978   EXPECT_EQ(client->getMiss(), 1);
979   EXPECT_EQ(client->getHit(), 0);
980
981   cerr << "SSLServerCacheCloseTest test completed" << endl;
982 }
983
984 /**
985  * Verify Client Ciphers obtained using SSL MSG Callback.
986  */
987 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
988   EventBase eventBase;
989   auto clientCtx = std::make_shared<SSLContext>();
990   auto serverCtx = std::make_shared<SSLContext>();
991   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
992   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
993   serverCtx->loadPrivateKey(testKey);
994   serverCtx->loadCertificate(testCert);
995   serverCtx->loadTrustedCertificates(testCA);
996   serverCtx->loadClientCAList(testCA);
997
998   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
999   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1000   clientCtx->loadPrivateKey(testKey);
1001   clientCtx->loadCertificate(testCert);
1002   clientCtx->loadTrustedCertificates(testCA);
1003
1004   int fds[2];
1005   getfds(fds);
1006
1007   AsyncSSLSocket::UniquePtr clientSock(
1008       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1009   AsyncSSLSocket::UniquePtr serverSock(
1010       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1011
1012   SSLHandshakeClient client(std::move(clientSock), true, true);
1013   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1014
1015   eventBase.loop();
1016
1017   EXPECT_EQ(server.clientCiphers_,
1018             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1019   EXPECT_TRUE(client.handshakeVerify_);
1020   EXPECT_TRUE(client.handshakeSuccess_);
1021   EXPECT_TRUE(!client.handshakeError_);
1022   EXPECT_TRUE(server.handshakeVerify_);
1023   EXPECT_TRUE(server.handshakeSuccess_);
1024   EXPECT_TRUE(!server.handshakeError_);
1025 }
1026
1027 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1028   EventBase eventBase;
1029   auto ctx = std::make_shared<SSLContext>();
1030
1031   int fds[2];
1032   getfds(fds);
1033
1034   int bufLen = 42;
1035   uint8_t majorVersion = 18;
1036   uint8_t minorVersion = 25;
1037
1038   // Create callback buf
1039   auto buf = IOBuf::create(bufLen);
1040   buf->append(bufLen);
1041   folly::io::RWPrivateCursor cursor(buf.get());
1042   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1043   cursor.write<uint16_t>(0);
1044   cursor.write<uint8_t>(38);
1045   cursor.write<uint8_t>(majorVersion);
1046   cursor.write<uint8_t>(minorVersion);
1047   cursor.skip(32);
1048   cursor.write<uint32_t>(0);
1049
1050   SSL* ssl = ctx->createSSL();
1051   SCOPE_EXIT { SSL_free(ssl); };
1052   AsyncSSLSocket::UniquePtr sock(
1053       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1054   sock->enableClientHelloParsing();
1055
1056   // Test client hello parsing in one packet
1057   AsyncSSLSocket::clientHelloParsingCallback(
1058       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1059   buf.reset();
1060
1061   auto parsedClientHello = sock->getClientHelloInfo();
1062   EXPECT_TRUE(parsedClientHello != nullptr);
1063   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1064   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1065 }
1066
1067 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1068   EventBase eventBase;
1069   auto ctx = std::make_shared<SSLContext>();
1070
1071   int fds[2];
1072   getfds(fds);
1073
1074   int bufLen = 42;
1075   uint8_t majorVersion = 18;
1076   uint8_t minorVersion = 25;
1077
1078   // Create callback buf
1079   auto buf = IOBuf::create(bufLen);
1080   buf->append(bufLen);
1081   folly::io::RWPrivateCursor cursor(buf.get());
1082   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1083   cursor.write<uint16_t>(0);
1084   cursor.write<uint8_t>(38);
1085   cursor.write<uint8_t>(majorVersion);
1086   cursor.write<uint8_t>(minorVersion);
1087   cursor.skip(32);
1088   cursor.write<uint32_t>(0);
1089
1090   SSL* ssl = ctx->createSSL();
1091   SCOPE_EXIT { SSL_free(ssl); };
1092   AsyncSSLSocket::UniquePtr sock(
1093       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1094   sock->enableClientHelloParsing();
1095
1096   // Test parsing with two packets with first packet size < 3
1097   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1098   AsyncSSLSocket::clientHelloParsingCallback(
1099       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1100       ssl, sock.get());
1101   bufCopy.reset();
1102   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1103   AsyncSSLSocket::clientHelloParsingCallback(
1104       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1105       ssl, sock.get());
1106   bufCopy.reset();
1107
1108   auto parsedClientHello = sock->getClientHelloInfo();
1109   EXPECT_TRUE(parsedClientHello != nullptr);
1110   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1111   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1112 }
1113
1114 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1115   EventBase eventBase;
1116   auto ctx = std::make_shared<SSLContext>();
1117
1118   int fds[2];
1119   getfds(fds);
1120
1121   int bufLen = 42;
1122   uint8_t majorVersion = 18;
1123   uint8_t minorVersion = 25;
1124
1125   // Create callback buf
1126   auto buf = IOBuf::create(bufLen);
1127   buf->append(bufLen);
1128   folly::io::RWPrivateCursor cursor(buf.get());
1129   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1130   cursor.write<uint16_t>(0);
1131   cursor.write<uint8_t>(38);
1132   cursor.write<uint8_t>(majorVersion);
1133   cursor.write<uint8_t>(minorVersion);
1134   cursor.skip(32);
1135   cursor.write<uint32_t>(0);
1136
1137   SSL* ssl = ctx->createSSL();
1138   SCOPE_EXIT { SSL_free(ssl); };
1139   AsyncSSLSocket::UniquePtr sock(
1140       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1141   sock->enableClientHelloParsing();
1142
1143   // Test parsing with multiple small packets
1144   for (uint64_t i = 0; i < buf->length(); i += 3) {
1145     auto bufCopy = folly::IOBuf::copyBuffer(
1146         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1147     AsyncSSLSocket::clientHelloParsingCallback(
1148         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1149         ssl, sock.get());
1150     bufCopy.reset();
1151   }
1152
1153   auto parsedClientHello = sock->getClientHelloInfo();
1154   EXPECT_TRUE(parsedClientHello != nullptr);
1155   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1156   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1157 }
1158
1159 /**
1160  * Verify sucessful behavior of SSL certificate validation.
1161  */
1162 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1163   EventBase eventBase;
1164   auto clientCtx = std::make_shared<SSLContext>();
1165   auto dfServerCtx = std::make_shared<SSLContext>();
1166
1167   int fds[2];
1168   getfds(fds);
1169   getctx(clientCtx, dfServerCtx);
1170
1171   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1172   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1173
1174   AsyncSSLSocket::UniquePtr clientSock(
1175     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1176   AsyncSSLSocket::UniquePtr serverSock(
1177     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1178
1179   SSLHandshakeClient client(std::move(clientSock), true, true);
1180   clientCtx->loadTrustedCertificates(testCA);
1181
1182   SSLHandshakeServer server(std::move(serverSock), true, true);
1183
1184   eventBase.loop();
1185
1186   EXPECT_TRUE(client.handshakeVerify_);
1187   EXPECT_TRUE(client.handshakeSuccess_);
1188   EXPECT_TRUE(!client.handshakeError_);
1189   EXPECT_LE(0, client.handshakeTime.count());
1190   EXPECT_TRUE(!server.handshakeVerify_);
1191   EXPECT_TRUE(server.handshakeSuccess_);
1192   EXPECT_TRUE(!server.handshakeError_);
1193   EXPECT_LE(0, server.handshakeTime.count());
1194 }
1195
1196 /**
1197  * Verify that the client's verification callback is able to fail SSL
1198  * connection establishment.
1199  */
1200 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1201   EventBase eventBase;
1202   auto clientCtx = std::make_shared<SSLContext>();
1203   auto dfServerCtx = std::make_shared<SSLContext>();
1204
1205   int fds[2];
1206   getfds(fds);
1207   getctx(clientCtx, dfServerCtx);
1208
1209   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1210   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1211
1212   AsyncSSLSocket::UniquePtr clientSock(
1213     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1214   AsyncSSLSocket::UniquePtr serverSock(
1215     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1216
1217   SSLHandshakeClient client(std::move(clientSock), true, false);
1218   clientCtx->loadTrustedCertificates(testCA);
1219
1220   SSLHandshakeServer server(std::move(serverSock), true, true);
1221
1222   eventBase.loop();
1223
1224   EXPECT_TRUE(client.handshakeVerify_);
1225   EXPECT_TRUE(!client.handshakeSuccess_);
1226   EXPECT_TRUE(client.handshakeError_);
1227   EXPECT_LE(0, client.handshakeTime.count());
1228   EXPECT_TRUE(!server.handshakeVerify_);
1229   EXPECT_TRUE(!server.handshakeSuccess_);
1230   EXPECT_TRUE(server.handshakeError_);
1231   EXPECT_LE(0, server.handshakeTime.count());
1232 }
1233
1234 /**
1235  * Verify that the options in SSLContext can be overridden in
1236  * sslConnect/Accept.i.e specifying that no validation should be performed
1237  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1238  * the validation callback.
1239  */
1240 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1241   EventBase eventBase;
1242   auto clientCtx = std::make_shared<SSLContext>();
1243   auto dfServerCtx = std::make_shared<SSLContext>();
1244
1245   int fds[2];
1246   getfds(fds);
1247   getctx(clientCtx, dfServerCtx);
1248
1249   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1250   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1251
1252   AsyncSSLSocket::UniquePtr clientSock(
1253     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1254   AsyncSSLSocket::UniquePtr serverSock(
1255     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1256
1257   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1258   clientCtx->loadTrustedCertificates(testCA);
1259
1260   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1261
1262   eventBase.loop();
1263
1264   EXPECT_TRUE(!client.handshakeVerify_);
1265   EXPECT_TRUE(client.handshakeSuccess_);
1266   EXPECT_TRUE(!client.handshakeError_);
1267   EXPECT_LE(0, client.handshakeTime.count());
1268   EXPECT_TRUE(!server.handshakeVerify_);
1269   EXPECT_TRUE(server.handshakeSuccess_);
1270   EXPECT_TRUE(!server.handshakeError_);
1271   EXPECT_LE(0, server.handshakeTime.count());
1272 }
1273
1274 /**
1275  * Verify that the options in SSLContext can be overridden in
1276  * sslConnect/Accept. Enable verification even if context says otherwise.
1277  * Test requireClientCert with client cert
1278  */
1279 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1280   EventBase eventBase;
1281   auto clientCtx = std::make_shared<SSLContext>();
1282   auto serverCtx = std::make_shared<SSLContext>();
1283   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1284   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1285   serverCtx->loadPrivateKey(testKey);
1286   serverCtx->loadCertificate(testCert);
1287   serverCtx->loadTrustedCertificates(testCA);
1288   serverCtx->loadClientCAList(testCA);
1289
1290   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1291   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1292   clientCtx->loadPrivateKey(testKey);
1293   clientCtx->loadCertificate(testCert);
1294   clientCtx->loadTrustedCertificates(testCA);
1295
1296   int fds[2];
1297   getfds(fds);
1298
1299   AsyncSSLSocket::UniquePtr clientSock(
1300       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1301   AsyncSSLSocket::UniquePtr serverSock(
1302       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1303
1304   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1305   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1306
1307   eventBase.loop();
1308
1309   EXPECT_TRUE(client.handshakeVerify_);
1310   EXPECT_TRUE(client.handshakeSuccess_);
1311   EXPECT_FALSE(client.handshakeError_);
1312   EXPECT_LE(0, client.handshakeTime.count());
1313   EXPECT_TRUE(server.handshakeVerify_);
1314   EXPECT_TRUE(server.handshakeSuccess_);
1315   EXPECT_FALSE(server.handshakeError_);
1316   EXPECT_LE(0, server.handshakeTime.count());
1317 }
1318
1319 /**
1320  * Verify that the client's verification callback is able to override
1321  * the preverification failure and allow a successful connection.
1322  */
1323 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1324   EventBase eventBase;
1325   auto clientCtx = std::make_shared<SSLContext>();
1326   auto dfServerCtx = std::make_shared<SSLContext>();
1327
1328   int fds[2];
1329   getfds(fds);
1330   getctx(clientCtx, dfServerCtx);
1331
1332   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1333   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1334
1335   AsyncSSLSocket::UniquePtr clientSock(
1336     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1337   AsyncSSLSocket::UniquePtr serverSock(
1338     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1339
1340   SSLHandshakeClient client(std::move(clientSock), false, true);
1341   SSLHandshakeServer server(std::move(serverSock), true, true);
1342
1343   eventBase.loop();
1344
1345   EXPECT_TRUE(client.handshakeVerify_);
1346   EXPECT_TRUE(client.handshakeSuccess_);
1347   EXPECT_TRUE(!client.handshakeError_);
1348   EXPECT_LE(0, client.handshakeTime.count());
1349   EXPECT_TRUE(!server.handshakeVerify_);
1350   EXPECT_TRUE(server.handshakeSuccess_);
1351   EXPECT_TRUE(!server.handshakeError_);
1352   EXPECT_LE(0, server.handshakeTime.count());
1353 }
1354
1355 /**
1356  * Verify that specifying that no validation should be performed allows an
1357  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1358  * callback.
1359  */
1360 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1361   EventBase eventBase;
1362   auto clientCtx = std::make_shared<SSLContext>();
1363   auto dfServerCtx = std::make_shared<SSLContext>();
1364
1365   int fds[2];
1366   getfds(fds);
1367   getctx(clientCtx, dfServerCtx);
1368
1369   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1370   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1371
1372   AsyncSSLSocket::UniquePtr clientSock(
1373     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1374   AsyncSSLSocket::UniquePtr serverSock(
1375     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1376
1377   SSLHandshakeClient client(std::move(clientSock), false, false);
1378   SSLHandshakeServer server(std::move(serverSock), false, false);
1379
1380   eventBase.loop();
1381
1382   EXPECT_TRUE(!client.handshakeVerify_);
1383   EXPECT_TRUE(client.handshakeSuccess_);
1384   EXPECT_TRUE(!client.handshakeError_);
1385   EXPECT_LE(0, client.handshakeTime.count());
1386   EXPECT_TRUE(!server.handshakeVerify_);
1387   EXPECT_TRUE(server.handshakeSuccess_);
1388   EXPECT_TRUE(!server.handshakeError_);
1389   EXPECT_LE(0, server.handshakeTime.count());
1390 }
1391
1392 /**
1393  * Test requireClientCert with client cert
1394  */
1395 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1396   EventBase eventBase;
1397   auto clientCtx = std::make_shared<SSLContext>();
1398   auto serverCtx = std::make_shared<SSLContext>();
1399   serverCtx->setVerificationOption(
1400       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1401   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1402   serverCtx->loadPrivateKey(testKey);
1403   serverCtx->loadCertificate(testCert);
1404   serverCtx->loadTrustedCertificates(testCA);
1405   serverCtx->loadClientCAList(testCA);
1406
1407   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1408   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1409   clientCtx->loadPrivateKey(testKey);
1410   clientCtx->loadCertificate(testCert);
1411   clientCtx->loadTrustedCertificates(testCA);
1412
1413   int fds[2];
1414   getfds(fds);
1415
1416   AsyncSSLSocket::UniquePtr clientSock(
1417       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1418   AsyncSSLSocket::UniquePtr serverSock(
1419       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1420
1421   SSLHandshakeClient client(std::move(clientSock), true, true);
1422   SSLHandshakeServer server(std::move(serverSock), true, true);
1423
1424   eventBase.loop();
1425
1426   EXPECT_TRUE(client.handshakeVerify_);
1427   EXPECT_TRUE(client.handshakeSuccess_);
1428   EXPECT_FALSE(client.handshakeError_);
1429   EXPECT_LE(0, client.handshakeTime.count());
1430   EXPECT_TRUE(server.handshakeVerify_);
1431   EXPECT_TRUE(server.handshakeSuccess_);
1432   EXPECT_FALSE(server.handshakeError_);
1433   EXPECT_LE(0, server.handshakeTime.count());
1434 }
1435
1436
1437 /**
1438  * Test requireClientCert with no client cert
1439  */
1440 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1441   EventBase eventBase;
1442   auto clientCtx = std::make_shared<SSLContext>();
1443   auto serverCtx = std::make_shared<SSLContext>();
1444   serverCtx->setVerificationOption(
1445       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1446   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1447   serverCtx->loadPrivateKey(testKey);
1448   serverCtx->loadCertificate(testCert);
1449   serverCtx->loadTrustedCertificates(testCA);
1450   serverCtx->loadClientCAList(testCA);
1451   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1452   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1453
1454   int fds[2];
1455   getfds(fds);
1456
1457   AsyncSSLSocket::UniquePtr clientSock(
1458       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1459   AsyncSSLSocket::UniquePtr serverSock(
1460       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1461
1462   SSLHandshakeClient client(std::move(clientSock), false, false);
1463   SSLHandshakeServer server(std::move(serverSock), false, false);
1464
1465   eventBase.loop();
1466
1467   EXPECT_FALSE(server.handshakeVerify_);
1468   EXPECT_FALSE(server.handshakeSuccess_);
1469   EXPECT_TRUE(server.handshakeError_);
1470   EXPECT_LE(0, client.handshakeTime.count());
1471   EXPECT_LE(0, server.handshakeTime.count());
1472 }
1473
1474 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1475   auto cert = getFileAsBuf(testCert);
1476   auto key = getFileAsBuf(testKey);
1477
1478   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1479   BIO_write(certBio.get(), cert.data(), cert.size());
1480   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1481   BIO_write(keyBio.get(), key.data(), key.size());
1482
1483   // Create SSL structs from buffers to get properties
1484   ssl::X509UniquePtr certStruct(
1485       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1486   ssl::EvpPkeyUniquePtr keyStruct(
1487       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1488   certBio = nullptr;
1489   keyBio = nullptr;
1490
1491   auto origCommonName = getCommonName(certStruct.get());
1492   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1493   certStruct = nullptr;
1494   keyStruct = nullptr;
1495
1496   auto ctx = std::make_shared<SSLContext>();
1497   ctx->loadPrivateKeyFromBufferPEM(key);
1498   ctx->loadCertificateFromBufferPEM(cert);
1499   ctx->loadTrustedCertificates(testCA);
1500
1501   ssl::SSLUniquePtr ssl(ctx->createSSL());
1502
1503   auto newCert = SSL_get_certificate(ssl.get());
1504   auto newKey = SSL_get_privatekey(ssl.get());
1505
1506   // Get properties from SSL struct
1507   auto newCommonName = getCommonName(newCert);
1508   auto newKeySize = EVP_PKEY_bits(newKey);
1509
1510   // Check that the key and cert have the expected properties
1511   EXPECT_EQ(origCommonName, newCommonName);
1512   EXPECT_EQ(origKeySize, newKeySize);
1513 }
1514
1515 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1516   EventBase eb;
1517
1518   // Set up SSL context.
1519   auto sslContext = std::make_shared<SSLContext>();
1520   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1521
1522   // create SSL socket
1523   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1524
1525   EXPECT_EQ(1500, socket->getMinWriteSize());
1526
1527   socket->setMinWriteSize(0);
1528   EXPECT_EQ(0, socket->getMinWriteSize());
1529   socket->setMinWriteSize(50000);
1530   EXPECT_EQ(50000, socket->getMinWriteSize());
1531 }
1532
1533 class ReadCallbackTerminator : public ReadCallback {
1534  public:
1535   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1536       : ReadCallback(wcb)
1537       , base_(base) {}
1538
1539   // Do not write data back, terminate the loop.
1540   void readDataAvailable(size_t len) noexcept override {
1541     std::cerr << "readDataAvailable, len " << len << std::endl;
1542
1543     currentBuffer.length = len;
1544
1545     buffers.push_back(currentBuffer);
1546     currentBuffer.reset();
1547     state = STATE_SUCCEEDED;
1548
1549     socket_->setReadCB(nullptr);
1550     base_->terminateLoopSoon();
1551   }
1552  private:
1553   EventBase* base_;
1554 };
1555
1556
1557 /**
1558  * Test a full unencrypted codepath
1559  */
1560 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1561   EventBase base;
1562
1563   auto clientCtx = std::make_shared<folly::SSLContext>();
1564   auto serverCtx = std::make_shared<folly::SSLContext>();
1565   int fds[2];
1566   getfds(fds);
1567   getctx(clientCtx, serverCtx);
1568   auto client = AsyncSSLSocket::newSocket(
1569                   clientCtx, &base, fds[0], false, true);
1570   auto server = AsyncSSLSocket::newSocket(
1571                   serverCtx, &base, fds[1], true, true);
1572
1573   ReadCallbackTerminator readCallback(&base, nullptr);
1574   server->setReadCB(&readCallback);
1575   readCallback.setSocket(server);
1576
1577   uint8_t buf[128];
1578   memset(buf, 'a', sizeof(buf));
1579   client->write(nullptr, buf, sizeof(buf));
1580
1581   // Check that bytes are unencrypted
1582   char c;
1583   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1584   EXPECT_EQ('a', c);
1585
1586   EventBaseAborter eba(&base, 3000);
1587   base.loop();
1588
1589   EXPECT_EQ(1, readCallback.buffers.size());
1590   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1591
1592   server->setReadCB(&readCallback);
1593
1594   // Unencrypted
1595   server->sslAccept(nullptr);
1596   client->sslConn(nullptr);
1597
1598   // Do NOT wait for handshake, writing should be queued and happen after
1599
1600   client->write(nullptr, buf, sizeof(buf));
1601
1602   // Check that bytes are *not* unencrypted
1603   char c2;
1604   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1605   EXPECT_NE('a', c2);
1606
1607
1608   base.loop();
1609
1610   EXPECT_EQ(2, readCallback.buffers.size());
1611   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1612 }
1613
1614 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1615   // Start listening on a local port
1616   WriteCallbackBase writeCallback;
1617   WriteErrorCallback readCallback(&writeCallback);
1618   HandshakeCallback handshakeCallback(&readCallback,
1619                                       HandshakeCallback::EXPECT_ERROR);
1620   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1621   TestSSLServer server(&acceptCallback);
1622
1623   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1624   socket->open();
1625   uint8_t buf[3] = {0x16, 0x03, 0x01};
1626   socket->write(buf, sizeof(buf));
1627   socket->closeWithReset();
1628
1629   handshakeCallback.waitForHandshake();
1630   EXPECT_NE(
1631       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1632   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1633 }
1634
1635 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1636   // Start listening on a local port
1637   WriteCallbackBase writeCallback;
1638   WriteErrorCallback readCallback(&writeCallback);
1639   HandshakeCallback handshakeCallback(&readCallback,
1640                                       HandshakeCallback::EXPECT_ERROR);
1641   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1642   TestSSLServer server(&acceptCallback);
1643
1644   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1645   socket->open();
1646   uint8_t buf[3] = {0x16, 0x03, 0x01};
1647   socket->write(buf, sizeof(buf));
1648   socket->close();
1649
1650   handshakeCallback.waitForHandshake();
1651   EXPECT_NE(
1652       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1653   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1654 }
1655
1656 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1657   // Start listening on a local port
1658   WriteCallbackBase writeCallback;
1659   WriteErrorCallback readCallback(&writeCallback);
1660   HandshakeCallback handshakeCallback(&readCallback,
1661                                       HandshakeCallback::EXPECT_ERROR);
1662   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1663   TestSSLServer server(&acceptCallback);
1664
1665   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1666   socket->open();
1667   uint8_t buf[256] = {0x16, 0x03};
1668   memset(buf + 2, 'a', sizeof(buf) - 2);
1669   socket->write(buf, sizeof(buf));
1670   socket->close();
1671
1672   handshakeCallback.waitForHandshake();
1673   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1674             std::string::npos);
1675   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1676             std::string::npos);
1677 }
1678
1679 } // namespace
1680
1681 ///////////////////////////////////////////////////////////////////////////
1682 // init_unit_test_suite
1683 ///////////////////////////////////////////////////////////////////////////
1684 namespace {
1685 struct Initializer {
1686   Initializer() {
1687     signal(SIGPIPE, SIG_IGN);
1688   }
1689 };
1690 Initializer initializer;
1691 } // anonymous