Stop abusing errno
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <fstream>
28 #include <gtest/gtest.h>
29 #include <iostream>
30 #include <list>
31 #include <set>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <openssl/bio.h>
35 #include <poll.h>
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/tcp.h>
39 #include <folly/io/Cursor.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 namespace folly {
49 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
50 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
51 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
52
53 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
54 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
55 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
56
57 constexpr size_t SSLClient::kMaxReadBufferSz;
58 constexpr size_t SSLClient::kMaxReadsPerEvent;
59
60 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
61     : ctx_(new folly::SSLContext),
62       acb_(acb),
63       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
64   // Set up the SSL context
65   ctx_->loadCertificate(testCert);
66   ctx_->loadPrivateKey(testKey);
67   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
68
69   acb_->ctx_ = ctx_;
70   acb_->base_ = &evb_;
71
72   //set up the listening socket
73   socket_->bind(0);
74   socket_->getAddress(&address_);
75   socket_->listen(100);
76   socket_->addAcceptCallback(acb_, &evb_);
77   socket_->startAccepting();
78
79   int ret = pthread_create(&thread_, nullptr, Main, this);
80   assert(ret == 0);
81   (void)ret;
82
83   std::cerr << "Accepting connections on " << address_ << std::endl;
84 }
85
86 void getfds(int fds[2]) {
87   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
88     FAIL() << "failed to create socketpair: " << strerror(errno);
89   }
90   for (int idx = 0; idx < 2; ++idx) {
91     int flags = fcntl(fds[idx], F_GETFL, 0);
92     if (flags == -1) {
93       FAIL() << "failed to get flags for socket " << idx << ": "
94              << strerror(errno);
95     }
96     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
97       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
98              << strerror(errno);
99     }
100   }
101 }
102
103 void getctx(
104   std::shared_ptr<folly::SSLContext> clientCtx,
105   std::shared_ptr<folly::SSLContext> serverCtx) {
106   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107
108   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
109   serverCtx->loadCertificate(
110       testCert);
111   serverCtx->loadPrivateKey(
112       testKey);
113 }
114
115 void sslsocketpair(
116   EventBase* eventBase,
117   AsyncSSLSocket::UniquePtr* clientSock,
118   AsyncSSLSocket::UniquePtr* serverSock) {
119   auto clientCtx = std::make_shared<folly::SSLContext>();
120   auto serverCtx = std::make_shared<folly::SSLContext>();
121   int fds[2];
122   getfds(fds);
123   getctx(clientCtx, serverCtx);
124   clientSock->reset(new AsyncSSLSocket(
125                       clientCtx, eventBase, fds[0], false));
126   serverSock->reset(new AsyncSSLSocket(
127                       serverCtx, eventBase, fds[1], true));
128
129   // (*clientSock)->setSendTimeout(100);
130   // (*serverSock)->setSendTimeout(100);
131 }
132
133 // client protocol filters
134 bool clientProtoFilterPickPony(unsigned char** client,
135   unsigned int* client_len, const unsigned char*, unsigned int ) {
136   //the protocol string in length prefixed byte string. the
137   //length byte is not included in the length
138   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
139   *client = p;
140   *client_len = 7;
141   return true;
142 }
143
144 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
145   const unsigned char*, unsigned int) {
146   return false;
147 }
148
149 std::string getFileAsBuf(const char* fileName) {
150   std::string buffer;
151   folly::readFile(fileName, buffer);
152   return buffer;
153 }
154
155 std::string getCommonName(X509* cert) {
156   X509_NAME* subject = X509_get_subject_name(cert);
157   std::string cn;
158   cn.resize(ub_common_name);
159   X509_NAME_get_text_by_NID(
160       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
161   return cn;
162 }
163
164 /**
165  * Test connecting to, writing to, reading from, and closing the
166  * connection to the SSL server.
167  */
168 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
169   // Start listening on a local port
170   WriteCallbackBase writeCallback;
171   ReadCallback readCallback(&writeCallback);
172   HandshakeCallback handshakeCallback(&readCallback);
173   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
174   TestSSLServer server(&acceptCallback);
175
176   // Set up SSL context.
177   std::shared_ptr<SSLContext> sslContext(new SSLContext());
178   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
179   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
180   //sslContext->authenticate(true, false);
181
182   // connect
183   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
184                                                  sslContext);
185   socket->open();
186
187   // write()
188   uint8_t buf[128];
189   memset(buf, 'a', sizeof(buf));
190   socket->write(buf, sizeof(buf));
191
192   // read()
193   uint8_t readbuf[128];
194   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
195   EXPECT_EQ(bytesRead, 128);
196   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
197
198   // close()
199   socket->close();
200
201   cerr << "ConnectWriteReadClose test completed" << endl;
202 }
203
204 /**
205  * Test reading after server close.
206  */
207 TEST(AsyncSSLSocketTest, ReadAfterClose) {
208   // Start listening on a local port
209   WriteCallbackBase writeCallback;
210   ReadEOFCallback readCallback(&writeCallback);
211   HandshakeCallback handshakeCallback(&readCallback);
212   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
213   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
214
215   // Set up SSL context.
216   auto sslContext = std::make_shared<SSLContext>();
217   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
218
219   auto socket =
220       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
221   socket->open();
222
223   // This should trigger an EOF on the client.
224   auto evb = handshakeCallback.getSocket()->getEventBase();
225   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
226   std::array<uint8_t, 128> readbuf;
227   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
228   EXPECT_EQ(0, bytesRead);
229 }
230
231 /**
232  * Test bad renegotiation
233  */
234 TEST(AsyncSSLSocketTest, Renegotiate) {
235   EventBase eventBase;
236   auto clientCtx = std::make_shared<SSLContext>();
237   auto dfServerCtx = std::make_shared<SSLContext>();
238   std::array<int, 2> fds;
239   getfds(fds.data());
240   getctx(clientCtx, dfServerCtx);
241
242   AsyncSSLSocket::UniquePtr clientSock(
243       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
244   AsyncSSLSocket::UniquePtr serverSock(
245       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
246   SSLHandshakeClient client(std::move(clientSock), true, true);
247   RenegotiatingServer server(std::move(serverSock));
248
249   while (!client.handshakeSuccess_ && !client.handshakeError_) {
250     eventBase.loopOnce();
251   }
252
253   ASSERT_TRUE(client.handshakeSuccess_);
254
255   auto sslSock = std::move(client).moveSocket();
256   sslSock->detachEventBase();
257   // This is nasty, however we don't want to add support for
258   // renegotiation in AsyncSSLSocket.
259   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
260
261   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
262
263   std::thread t([&]() { eventBase.loopForever(); });
264
265   // Trigger the renegotiation.
266   std::array<uint8_t, 128> buf;
267   memset(buf.data(), 'a', buf.size());
268   try {
269     socket->write(buf.data(), buf.size());
270   } catch (AsyncSocketException& e) {
271     LOG(INFO) << "client got error " << e.what();
272   }
273   eventBase.terminateLoopSoon();
274   t.join();
275
276   eventBase.loop();
277   ASSERT_TRUE(server.renegotiationError_);
278 }
279
280 /**
281  * Negative test for handshakeError().
282  */
283 TEST(AsyncSSLSocketTest, HandshakeError) {
284   // Start listening on a local port
285   WriteCallbackBase writeCallback;
286   WriteErrorCallback readCallback(&writeCallback);
287   HandshakeCallback handshakeCallback(&readCallback);
288   HandshakeErrorCallback acceptCallback(&handshakeCallback);
289   TestSSLServer server(&acceptCallback);
290
291   // Set up SSL context.
292   std::shared_ptr<SSLContext> sslContext(new SSLContext());
293   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
294
295   // connect
296   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
297                                                  sslContext);
298   // read()
299   bool ex = false;
300   try {
301     socket->open();
302
303     uint8_t readbuf[128];
304     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
305     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
306   } catch (AsyncSocketException &e) {
307     ex = true;
308   }
309   EXPECT_TRUE(ex);
310
311   // close()
312   socket->close();
313   cerr << "HandshakeError test completed" << endl;
314 }
315
316 /**
317  * Negative test for readError().
318  */
319 TEST(AsyncSSLSocketTest, ReadError) {
320   // Start listening on a local port
321   WriteCallbackBase writeCallback;
322   ReadErrorCallback readCallback(&writeCallback);
323   HandshakeCallback handshakeCallback(&readCallback);
324   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
325   TestSSLServer server(&acceptCallback);
326
327   // Set up SSL context.
328   std::shared_ptr<SSLContext> sslContext(new SSLContext());
329   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
330
331   // connect
332   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
333                                                  sslContext);
334   socket->open();
335
336   // write something to trigger ssl handshake
337   uint8_t buf[128];
338   memset(buf, 'a', sizeof(buf));
339   socket->write(buf, sizeof(buf));
340
341   socket->close();
342   cerr << "ReadError test completed" << endl;
343 }
344
345 /**
346  * Negative test for writeError().
347  */
348 TEST(AsyncSSLSocketTest, WriteError) {
349   // Start listening on a local port
350   WriteCallbackBase writeCallback;
351   WriteErrorCallback readCallback(&writeCallback);
352   HandshakeCallback handshakeCallback(&readCallback);
353   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
354   TestSSLServer server(&acceptCallback);
355
356   // Set up SSL context.
357   std::shared_ptr<SSLContext> sslContext(new SSLContext());
358   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
359
360   // connect
361   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
362                                                  sslContext);
363   socket->open();
364
365   // write something to trigger ssl handshake
366   uint8_t buf[128];
367   memset(buf, 'a', sizeof(buf));
368   socket->write(buf, sizeof(buf));
369
370   socket->close();
371   cerr << "WriteError test completed" << endl;
372 }
373
374 /**
375  * Test a socket with TCP_NODELAY unset.
376  */
377 TEST(AsyncSSLSocketTest, SocketWithDelay) {
378   // Start listening on a local port
379   WriteCallbackBase writeCallback;
380   ReadCallback readCallback(&writeCallback);
381   HandshakeCallback handshakeCallback(&readCallback);
382   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
383   TestSSLServer server(&acceptCallback);
384
385   // Set up SSL context.
386   std::shared_ptr<SSLContext> sslContext(new SSLContext());
387   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
388
389   // connect
390   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
391                                                  sslContext);
392   socket->open();
393
394   // write()
395   uint8_t buf[128];
396   memset(buf, 'a', sizeof(buf));
397   socket->write(buf, sizeof(buf));
398
399   // read()
400   uint8_t readbuf[128];
401   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
402   EXPECT_EQ(bytesRead, 128);
403   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
404
405   // close()
406   socket->close();
407
408   cerr << "SocketWithDelay test completed" << endl;
409 }
410
411 using NextProtocolTypePair =
412     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
413
414 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
415   // For matching protos
416  public:
417   void SetUp() override { getctx(clientCtx, serverCtx); }
418
419   void connect(bool unset = false) {
420     getfds(fds);
421
422     if (unset) {
423       // unsetting NPN for any of [client, server] is enough to make NPN not
424       // work
425       clientCtx->unsetNextProtocols();
426     }
427
428     AsyncSSLSocket::UniquePtr clientSock(
429       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
430     AsyncSSLSocket::UniquePtr serverSock(
431       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
432     client = folly::make_unique<NpnClient>(std::move(clientSock));
433     server = folly::make_unique<NpnServer>(std::move(serverSock));
434
435     eventBase.loop();
436   }
437
438   void expectProtocol(const std::string& proto) {
439     EXPECT_NE(client->nextProtoLength, 0);
440     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
441     EXPECT_EQ(
442         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
443         0);
444     string selected((const char*)client->nextProto, client->nextProtoLength);
445     EXPECT_EQ(proto, selected);
446   }
447
448   void expectNoProtocol() {
449     EXPECT_EQ(client->nextProtoLength, 0);
450     EXPECT_EQ(server->nextProtoLength, 0);
451     EXPECT_EQ(client->nextProto, nullptr);
452     EXPECT_EQ(server->nextProto, nullptr);
453   }
454
455   void expectProtocolType() {
456     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
457         GetParam().second == SSLContext::NextProtocolType::ANY) {
458       EXPECT_EQ(client->protocolType, server->protocolType);
459     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
460                GetParam().second == SSLContext::NextProtocolType::ANY) {
461       // Well not much we can say
462     } else {
463       expectProtocolType(GetParam());
464     }
465   }
466
467   void expectProtocolType(NextProtocolTypePair expected) {
468     EXPECT_EQ(client->protocolType, expected.first);
469     EXPECT_EQ(server->protocolType, expected.second);
470   }
471
472   EventBase eventBase;
473   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
474   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
475   int fds[2];
476   std::unique_ptr<NpnClient> client;
477   std::unique_ptr<NpnServer> server;
478 };
479
480 class NextProtocolNPNOnlyTest : public NextProtocolTest {
481   // For mismatching protos
482 };
483
484 class NextProtocolMismatchTest : public NextProtocolTest {
485   // For mismatching protos
486 };
487
488 TEST_P(NextProtocolTest, NpnTestOverlap) {
489   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
490   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
491                                         GetParam().second);
492
493   connect();
494
495   expectProtocol("baz");
496   expectProtocolType();
497 }
498
499 TEST_P(NextProtocolTest, NpnTestUnset) {
500   // Identical to above test, except that we want unset NPN before
501   // looping.
502   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
503   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
504                                         GetParam().second);
505
506   connect(true /* unset */);
507
508   // if alpn negotiation fails, type will appear as npn
509   expectNoProtocol();
510   EXPECT_EQ(client->protocolType, server->protocolType);
511 }
512
513 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
514   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
515   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
516                                         GetParam().second);
517
518   connect();
519
520   expectNoProtocol();
521   expectProtocolType(
522       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
523 }
524
525 TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529
530   connect();
531
532   expectProtocol("blub");
533   expectProtocolType();
534 }
535
536 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
537   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
538   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
539   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
540                                         GetParam().second);
541
542   connect();
543
544   expectProtocol("ponies");
545   expectProtocolType();
546 }
547
548 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
549   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
550   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
551   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
552                                         GetParam().second);
553
554   connect();
555
556   expectProtocol("blub");
557   expectProtocolType();
558 }
559
560 TEST_P(NextProtocolTest, RandomizedNpnTest) {
561   // Probability that this test will fail is 2^-64, which could be considered
562   // as negligible.
563   const int kTries = 64;
564
565   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
566                                         GetParam().first);
567   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
568                                                   GetParam().second);
569
570   std::set<string> selectedProtocols;
571   for (int i = 0; i < kTries; ++i) {
572     connect();
573
574     EXPECT_NE(client->nextProtoLength, 0);
575     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
576     EXPECT_EQ(
577         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
578         0);
579     string selected((const char*)client->nextProto, client->nextProtoLength);
580     selectedProtocols.insert(selected);
581     expectProtocolType();
582   }
583   EXPECT_EQ(selectedProtocols.size(), 2);
584 }
585
586 INSTANTIATE_TEST_CASE_P(
587     AsyncSSLSocketTest,
588     NextProtocolTest,
589     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
590                                            SSLContext::NextProtocolType::NPN),
591 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
592                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
593                                            SSLContext::NextProtocolType::ALPN),
594 #endif
595                       NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
596                                            SSLContext::NextProtocolType::ANY),
597 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
598                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
599                                            SSLContext::NextProtocolType::ANY),
600 #endif
601                       NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
602                                            SSLContext::NextProtocolType::ANY)));
603
604 INSTANTIATE_TEST_CASE_P(
605     AsyncSSLSocketTest,
606     NextProtocolNPNOnlyTest,
607     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
608                                            SSLContext::NextProtocolType::NPN)));
609
610 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
611 INSTANTIATE_TEST_CASE_P(
612     AsyncSSLSocketTest,
613     NextProtocolMismatchTest,
614     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
615                                            SSLContext::NextProtocolType::ALPN),
616                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
617                                            SSLContext::NextProtocolType::NPN)));
618 #endif
619
620 #ifndef OPENSSL_NO_TLSEXT
621 /**
622  * 1. Client sends TLSEXT_HOSTNAME in client hello.
623  * 2. Server found a match SSL_CTX and use this SSL_CTX to
624  *    continue the SSL handshake.
625  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
626  */
627 TEST(AsyncSSLSocketTest, SNITestMatch) {
628   EventBase eventBase;
629   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
630   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
631   // Use the same SSLContext to continue the handshake after
632   // tlsext_hostname match.
633   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
634   const std::string serverName("xyz.newdev.facebook.com");
635   int fds[2];
636   getfds(fds);
637   getctx(clientCtx, dfServerCtx);
638
639   AsyncSSLSocket::UniquePtr clientSock(
640     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
641   AsyncSSLSocket::UniquePtr serverSock(
642     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
643   SNIClient client(std::move(clientSock));
644   SNIServer server(std::move(serverSock),
645                    dfServerCtx,
646                    hskServerCtx,
647                    serverName);
648
649   eventBase.loop();
650
651   EXPECT_TRUE(client.serverNameMatch);
652   EXPECT_TRUE(server.serverNameMatch);
653 }
654
655 /**
656  * 1. Client sends TLSEXT_HOSTNAME in client hello.
657  * 2. Server cannot find a matching SSL_CTX and continue to use
658  *    the current SSL_CTX to do the handshake.
659  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
660  */
661 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
662   EventBase eventBase;
663   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
664   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
665   // Use the same SSLContext to continue the handshake after
666   // tlsext_hostname match.
667   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
668   const std::string clientRequestingServerName("foo.com");
669   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
670
671   int fds[2];
672   getfds(fds);
673   getctx(clientCtx, dfServerCtx);
674
675   AsyncSSLSocket::UniquePtr clientSock(
676     new AsyncSSLSocket(clientCtx,
677                         &eventBase,
678                         fds[0],
679                         clientRequestingServerName));
680   AsyncSSLSocket::UniquePtr serverSock(
681     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
682   SNIClient client(std::move(clientSock));
683   SNIServer server(std::move(serverSock),
684                    dfServerCtx,
685                    hskServerCtx,
686                    serverExpectedServerName);
687
688   eventBase.loop();
689
690   EXPECT_TRUE(!client.serverNameMatch);
691   EXPECT_TRUE(!server.serverNameMatch);
692 }
693 /**
694  * 1. Client sends TLSEXT_HOSTNAME in client hello.
695  * 2. We then change the serverName.
696  * 3. We expect that we get 'false' as the result for serNameMatch.
697  */
698
699 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
700    EventBase eventBase;
701   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
702   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
703   // Use the same SSLContext to continue the handshake after
704   // tlsext_hostname match.
705   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
706   const std::string serverName("xyz.newdev.facebook.com");
707   int fds[2];
708   getfds(fds);
709   getctx(clientCtx, dfServerCtx);
710
711   AsyncSSLSocket::UniquePtr clientSock(
712     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
713   //Change the server name
714   std::string newName("new.com");
715   clientSock->setServerName(newName);
716   AsyncSSLSocket::UniquePtr serverSock(
717     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
718   SNIClient client(std::move(clientSock));
719   SNIServer server(std::move(serverSock),
720                    dfServerCtx,
721                    hskServerCtx,
722                    serverName);
723
724   eventBase.loop();
725
726   EXPECT_TRUE(!client.serverNameMatch);
727 }
728
729 /**
730  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
731  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
732  */
733 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
734   EventBase eventBase;
735   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
736   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
737   // Use the same SSLContext to continue the handshake after
738   // tlsext_hostname match.
739   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
740   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
741
742   int fds[2];
743   getfds(fds);
744   getctx(clientCtx, dfServerCtx);
745
746   AsyncSSLSocket::UniquePtr clientSock(
747     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
748   AsyncSSLSocket::UniquePtr serverSock(
749     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
750   SNIClient client(std::move(clientSock));
751   SNIServer server(std::move(serverSock),
752                    dfServerCtx,
753                    hskServerCtx,
754                    serverExpectedServerName);
755
756   eventBase.loop();
757
758   EXPECT_TRUE(!client.serverNameMatch);
759   EXPECT_TRUE(!server.serverNameMatch);
760 }
761
762 #endif
763 /**
764  * Test SSL client socket
765  */
766 TEST(AsyncSSLSocketTest, SSLClientTest) {
767   // Start listening on a local port
768   WriteCallbackBase writeCallback;
769   ReadCallback readCallback(&writeCallback);
770   HandshakeCallback handshakeCallback(&readCallback);
771   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
772   TestSSLServer server(&acceptCallback);
773
774   // Set up SSL client
775   EventBase eventBase;
776   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
777
778   client->connect();
779   EventBaseAborter eba(&eventBase, 3000);
780   eventBase.loop();
781
782   EXPECT_EQ(client->getMiss(), 1);
783   EXPECT_EQ(client->getHit(), 0);
784
785   cerr << "SSLClientTest test completed" << endl;
786 }
787
788
789 /**
790  * Test SSL client socket session re-use
791  */
792 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
793   // Start listening on a local port
794   WriteCallbackBase writeCallback;
795   ReadCallback readCallback(&writeCallback);
796   HandshakeCallback handshakeCallback(&readCallback);
797   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
798   TestSSLServer server(&acceptCallback);
799
800   // Set up SSL client
801   EventBase eventBase;
802   auto client =
803       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
804
805   client->connect();
806   EventBaseAborter eba(&eventBase, 3000);
807   eventBase.loop();
808
809   EXPECT_EQ(client->getMiss(), 1);
810   EXPECT_EQ(client->getHit(), 9);
811
812   cerr << "SSLClientTestReuse test completed" << endl;
813 }
814
815 /**
816  * Test SSL client socket timeout
817  */
818 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
819   // Start listening on a local port
820   EmptyReadCallback readCallback;
821   HandshakeCallback handshakeCallback(&readCallback,
822                                       HandshakeCallback::EXPECT_ERROR);
823   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
824   TestSSLServer server(&acceptCallback);
825
826   // Set up SSL client
827   EventBase eventBase;
828   auto client =
829       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
830   client->connect(true /* write before connect completes */);
831   EventBaseAborter eba(&eventBase, 3000);
832   eventBase.loop();
833
834   usleep(100000);
835   // This is checking that the connectError callback precedes any queued
836   // writeError callbacks.  This matches AsyncSocket's behavior
837   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
838   EXPECT_EQ(client->getErrors(), 1);
839   EXPECT_EQ(client->getMiss(), 0);
840   EXPECT_EQ(client->getHit(), 0);
841
842   cerr << "SSLClientTimeoutTest test completed" << endl;
843 }
844
845
846 /**
847  * Test SSL server async cache
848  */
849 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
850   // Start listening on a local port
851   WriteCallbackBase writeCallback;
852   ReadCallback readCallback(&writeCallback);
853   HandshakeCallback handshakeCallback(&readCallback);
854   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
855   TestSSLAsyncCacheServer server(&acceptCallback);
856
857   // Set up SSL client
858   EventBase eventBase;
859   auto client =
860       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
861
862   client->connect();
863   EventBaseAborter eba(&eventBase, 3000);
864   eventBase.loop();
865
866   EXPECT_EQ(server.getAsyncCallbacks(), 18);
867   EXPECT_EQ(server.getAsyncLookups(), 9);
868   EXPECT_EQ(client->getMiss(), 10);
869   EXPECT_EQ(client->getHit(), 0);
870
871   cerr << "SSLServerAsyncCacheTest test completed" << endl;
872 }
873
874
875 /**
876  * Test SSL server accept timeout with cache path
877  */
878 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
879   // Start listening on a local port
880   WriteCallbackBase writeCallback;
881   ReadCallback readCallback(&writeCallback);
882   EmptyReadCallback clientReadCallback;
883   HandshakeCallback handshakeCallback(&readCallback);
884   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
885   TestSSLAsyncCacheServer server(&acceptCallback);
886
887   // Set up SSL client
888   EventBase eventBase;
889   // only do a TCP connect
890   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
891   sock->connect(nullptr, server.getAddress());
892   clientReadCallback.tcpSocket_ = sock;
893   sock->setReadCB(&clientReadCallback);
894
895   EventBaseAborter eba(&eventBase, 3000);
896   eventBase.loop();
897
898   EXPECT_EQ(readCallback.state, STATE_WAITING);
899
900   cerr << "SSLServerTimeoutTest test completed" << endl;
901 }
902
903 /**
904  * Test SSL server accept timeout with cache path
905  */
906 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
907   // Start listening on a local port
908   WriteCallbackBase writeCallback;
909   ReadCallback readCallback(&writeCallback);
910   HandshakeCallback handshakeCallback(&readCallback);
911   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
912   TestSSLAsyncCacheServer server(&acceptCallback);
913
914   // Set up SSL client
915   EventBase eventBase;
916   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
917
918   client->connect();
919   EventBaseAborter eba(&eventBase, 3000);
920   eventBase.loop();
921
922   EXPECT_EQ(server.getAsyncCallbacks(), 1);
923   EXPECT_EQ(server.getAsyncLookups(), 1);
924   EXPECT_EQ(client->getErrors(), 1);
925   EXPECT_EQ(client->getMiss(), 1);
926   EXPECT_EQ(client->getHit(), 0);
927
928   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
929 }
930
931 /**
932  * Test SSL server accept timeout with cache path
933  */
934 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
935   // Start listening on a local port
936   WriteCallbackBase writeCallback;
937   ReadCallback readCallback(&writeCallback);
938   HandshakeCallback handshakeCallback(&readCallback,
939                                       HandshakeCallback::EXPECT_ERROR);
940   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
941   TestSSLAsyncCacheServer server(&acceptCallback, 500);
942
943   // Set up SSL client
944   EventBase eventBase;
945   auto client =
946       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
947
948   client->connect();
949   EventBaseAborter eba(&eventBase, 3000);
950   eventBase.loop();
951
952   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
953       handshakeCallback.closeSocket();});
954   // give time for the cache lookup to come back and find it closed
955   handshakeCallback.waitForHandshake();
956
957   EXPECT_EQ(server.getAsyncCallbacks(), 1);
958   EXPECT_EQ(server.getAsyncLookups(), 1);
959   EXPECT_EQ(client->getErrors(), 1);
960   EXPECT_EQ(client->getMiss(), 1);
961   EXPECT_EQ(client->getHit(), 0);
962
963   cerr << "SSLServerCacheCloseTest test completed" << endl;
964 }
965
966 /**
967  * Verify Client Ciphers obtained using SSL MSG Callback.
968  */
969 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
970   EventBase eventBase;
971   auto clientCtx = std::make_shared<SSLContext>();
972   auto serverCtx = std::make_shared<SSLContext>();
973   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
974   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
975   serverCtx->loadPrivateKey(testKey);
976   serverCtx->loadCertificate(testCert);
977   serverCtx->loadTrustedCertificates(testCA);
978   serverCtx->loadClientCAList(testCA);
979
980   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
981   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
982   clientCtx->loadPrivateKey(testKey);
983   clientCtx->loadCertificate(testCert);
984   clientCtx->loadTrustedCertificates(testCA);
985
986   int fds[2];
987   getfds(fds);
988
989   AsyncSSLSocket::UniquePtr clientSock(
990       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
991   AsyncSSLSocket::UniquePtr serverSock(
992       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
993
994   SSLHandshakeClient client(std::move(clientSock), true, true);
995   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
996
997   eventBase.loop();
998
999   EXPECT_EQ(server.clientCiphers_,
1000             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1001   EXPECT_TRUE(client.handshakeVerify_);
1002   EXPECT_TRUE(client.handshakeSuccess_);
1003   EXPECT_TRUE(!client.handshakeError_);
1004   EXPECT_TRUE(server.handshakeVerify_);
1005   EXPECT_TRUE(server.handshakeSuccess_);
1006   EXPECT_TRUE(!server.handshakeError_);
1007 }
1008
1009 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1010   EventBase eventBase;
1011   auto ctx = std::make_shared<SSLContext>();
1012
1013   int fds[2];
1014   getfds(fds);
1015
1016   int bufLen = 42;
1017   uint8_t majorVersion = 18;
1018   uint8_t minorVersion = 25;
1019
1020   // Create callback buf
1021   auto buf = IOBuf::create(bufLen);
1022   buf->append(bufLen);
1023   folly::io::RWPrivateCursor cursor(buf.get());
1024   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1025   cursor.write<uint16_t>(0);
1026   cursor.write<uint8_t>(38);
1027   cursor.write<uint8_t>(majorVersion);
1028   cursor.write<uint8_t>(minorVersion);
1029   cursor.skip(32);
1030   cursor.write<uint32_t>(0);
1031
1032   SSL* ssl = ctx->createSSL();
1033   SCOPE_EXIT { SSL_free(ssl); };
1034   AsyncSSLSocket::UniquePtr sock(
1035       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1036   sock->enableClientHelloParsing();
1037
1038   // Test client hello parsing in one packet
1039   AsyncSSLSocket::clientHelloParsingCallback(
1040       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1041   buf.reset();
1042
1043   auto parsedClientHello = sock->getClientHelloInfo();
1044   EXPECT_TRUE(parsedClientHello != nullptr);
1045   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1046   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1047 }
1048
1049 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1050   EventBase eventBase;
1051   auto ctx = std::make_shared<SSLContext>();
1052
1053   int fds[2];
1054   getfds(fds);
1055
1056   int bufLen = 42;
1057   uint8_t majorVersion = 18;
1058   uint8_t minorVersion = 25;
1059
1060   // Create callback buf
1061   auto buf = IOBuf::create(bufLen);
1062   buf->append(bufLen);
1063   folly::io::RWPrivateCursor cursor(buf.get());
1064   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1065   cursor.write<uint16_t>(0);
1066   cursor.write<uint8_t>(38);
1067   cursor.write<uint8_t>(majorVersion);
1068   cursor.write<uint8_t>(minorVersion);
1069   cursor.skip(32);
1070   cursor.write<uint32_t>(0);
1071
1072   SSL* ssl = ctx->createSSL();
1073   SCOPE_EXIT { SSL_free(ssl); };
1074   AsyncSSLSocket::UniquePtr sock(
1075       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1076   sock->enableClientHelloParsing();
1077
1078   // Test parsing with two packets with first packet size < 3
1079   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1080   AsyncSSLSocket::clientHelloParsingCallback(
1081       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1082       ssl, sock.get());
1083   bufCopy.reset();
1084   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1085   AsyncSSLSocket::clientHelloParsingCallback(
1086       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1087       ssl, sock.get());
1088   bufCopy.reset();
1089
1090   auto parsedClientHello = sock->getClientHelloInfo();
1091   EXPECT_TRUE(parsedClientHello != nullptr);
1092   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1093   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1094 }
1095
1096 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1097   EventBase eventBase;
1098   auto ctx = std::make_shared<SSLContext>();
1099
1100   int fds[2];
1101   getfds(fds);
1102
1103   int bufLen = 42;
1104   uint8_t majorVersion = 18;
1105   uint8_t minorVersion = 25;
1106
1107   // Create callback buf
1108   auto buf = IOBuf::create(bufLen);
1109   buf->append(bufLen);
1110   folly::io::RWPrivateCursor cursor(buf.get());
1111   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1112   cursor.write<uint16_t>(0);
1113   cursor.write<uint8_t>(38);
1114   cursor.write<uint8_t>(majorVersion);
1115   cursor.write<uint8_t>(minorVersion);
1116   cursor.skip(32);
1117   cursor.write<uint32_t>(0);
1118
1119   SSL* ssl = ctx->createSSL();
1120   SCOPE_EXIT { SSL_free(ssl); };
1121   AsyncSSLSocket::UniquePtr sock(
1122       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1123   sock->enableClientHelloParsing();
1124
1125   // Test parsing with multiple small packets
1126   for (uint64_t i = 0; i < buf->length(); i += 3) {
1127     auto bufCopy = folly::IOBuf::copyBuffer(
1128         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1129     AsyncSSLSocket::clientHelloParsingCallback(
1130         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1131         ssl, sock.get());
1132     bufCopy.reset();
1133   }
1134
1135   auto parsedClientHello = sock->getClientHelloInfo();
1136   EXPECT_TRUE(parsedClientHello != nullptr);
1137   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1138   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1139 }
1140
1141 /**
1142  * Verify sucessful behavior of SSL certificate validation.
1143  */
1144 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1145   EventBase eventBase;
1146   auto clientCtx = std::make_shared<SSLContext>();
1147   auto dfServerCtx = std::make_shared<SSLContext>();
1148
1149   int fds[2];
1150   getfds(fds);
1151   getctx(clientCtx, dfServerCtx);
1152
1153   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1154   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1155
1156   AsyncSSLSocket::UniquePtr clientSock(
1157     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1158   AsyncSSLSocket::UniquePtr serverSock(
1159     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1160
1161   SSLHandshakeClient client(std::move(clientSock), true, true);
1162   clientCtx->loadTrustedCertificates(testCA);
1163
1164   SSLHandshakeServer server(std::move(serverSock), true, true);
1165
1166   eventBase.loop();
1167
1168   EXPECT_TRUE(client.handshakeVerify_);
1169   EXPECT_TRUE(client.handshakeSuccess_);
1170   EXPECT_TRUE(!client.handshakeError_);
1171   EXPECT_LE(0, client.handshakeTime.count());
1172   EXPECT_TRUE(!server.handshakeVerify_);
1173   EXPECT_TRUE(server.handshakeSuccess_);
1174   EXPECT_TRUE(!server.handshakeError_);
1175   EXPECT_LE(0, server.handshakeTime.count());
1176 }
1177
1178 /**
1179  * Verify that the client's verification callback is able to fail SSL
1180  * connection establishment.
1181  */
1182 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1183   EventBase eventBase;
1184   auto clientCtx = std::make_shared<SSLContext>();
1185   auto dfServerCtx = std::make_shared<SSLContext>();
1186
1187   int fds[2];
1188   getfds(fds);
1189   getctx(clientCtx, dfServerCtx);
1190
1191   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1192   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1193
1194   AsyncSSLSocket::UniquePtr clientSock(
1195     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1196   AsyncSSLSocket::UniquePtr serverSock(
1197     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1198
1199   SSLHandshakeClient client(std::move(clientSock), true, false);
1200   clientCtx->loadTrustedCertificates(testCA);
1201
1202   SSLHandshakeServer server(std::move(serverSock), true, true);
1203
1204   eventBase.loop();
1205
1206   EXPECT_TRUE(client.handshakeVerify_);
1207   EXPECT_TRUE(!client.handshakeSuccess_);
1208   EXPECT_TRUE(client.handshakeError_);
1209   EXPECT_LE(0, client.handshakeTime.count());
1210   EXPECT_TRUE(!server.handshakeVerify_);
1211   EXPECT_TRUE(!server.handshakeSuccess_);
1212   EXPECT_TRUE(server.handshakeError_);
1213   EXPECT_LE(0, server.handshakeTime.count());
1214 }
1215
1216 /**
1217  * Verify that the options in SSLContext can be overridden in
1218  * sslConnect/Accept.i.e specifying that no validation should be performed
1219  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1220  * the validation callback.
1221  */
1222 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1223   EventBase eventBase;
1224   auto clientCtx = std::make_shared<SSLContext>();
1225   auto dfServerCtx = std::make_shared<SSLContext>();
1226
1227   int fds[2];
1228   getfds(fds);
1229   getctx(clientCtx, dfServerCtx);
1230
1231   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1232   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1233
1234   AsyncSSLSocket::UniquePtr clientSock(
1235     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1236   AsyncSSLSocket::UniquePtr serverSock(
1237     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1238
1239   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1240   clientCtx->loadTrustedCertificates(testCA);
1241
1242   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1243
1244   eventBase.loop();
1245
1246   EXPECT_TRUE(!client.handshakeVerify_);
1247   EXPECT_TRUE(client.handshakeSuccess_);
1248   EXPECT_TRUE(!client.handshakeError_);
1249   EXPECT_LE(0, client.handshakeTime.count());
1250   EXPECT_TRUE(!server.handshakeVerify_);
1251   EXPECT_TRUE(server.handshakeSuccess_);
1252   EXPECT_TRUE(!server.handshakeError_);
1253   EXPECT_LE(0, server.handshakeTime.count());
1254 }
1255
1256 /**
1257  * Verify that the options in SSLContext can be overridden in
1258  * sslConnect/Accept. Enable verification even if context says otherwise.
1259  * Test requireClientCert with client cert
1260  */
1261 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1262   EventBase eventBase;
1263   auto clientCtx = std::make_shared<SSLContext>();
1264   auto serverCtx = std::make_shared<SSLContext>();
1265   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1266   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1267   serverCtx->loadPrivateKey(testKey);
1268   serverCtx->loadCertificate(testCert);
1269   serverCtx->loadTrustedCertificates(testCA);
1270   serverCtx->loadClientCAList(testCA);
1271
1272   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1273   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1274   clientCtx->loadPrivateKey(testKey);
1275   clientCtx->loadCertificate(testCert);
1276   clientCtx->loadTrustedCertificates(testCA);
1277
1278   int fds[2];
1279   getfds(fds);
1280
1281   AsyncSSLSocket::UniquePtr clientSock(
1282       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1283   AsyncSSLSocket::UniquePtr serverSock(
1284       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1285
1286   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1287   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1288
1289   eventBase.loop();
1290
1291   EXPECT_TRUE(client.handshakeVerify_);
1292   EXPECT_TRUE(client.handshakeSuccess_);
1293   EXPECT_FALSE(client.handshakeError_);
1294   EXPECT_LE(0, client.handshakeTime.count());
1295   EXPECT_TRUE(server.handshakeVerify_);
1296   EXPECT_TRUE(server.handshakeSuccess_);
1297   EXPECT_FALSE(server.handshakeError_);
1298   EXPECT_LE(0, server.handshakeTime.count());
1299 }
1300
1301 /**
1302  * Verify that the client's verification callback is able to override
1303  * the preverification failure and allow a successful connection.
1304  */
1305 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1306   EventBase eventBase;
1307   auto clientCtx = std::make_shared<SSLContext>();
1308   auto dfServerCtx = std::make_shared<SSLContext>();
1309
1310   int fds[2];
1311   getfds(fds);
1312   getctx(clientCtx, dfServerCtx);
1313
1314   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1315   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1316
1317   AsyncSSLSocket::UniquePtr clientSock(
1318     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1319   AsyncSSLSocket::UniquePtr serverSock(
1320     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1321
1322   SSLHandshakeClient client(std::move(clientSock), false, true);
1323   SSLHandshakeServer server(std::move(serverSock), true, true);
1324
1325   eventBase.loop();
1326
1327   EXPECT_TRUE(client.handshakeVerify_);
1328   EXPECT_TRUE(client.handshakeSuccess_);
1329   EXPECT_TRUE(!client.handshakeError_);
1330   EXPECT_LE(0, client.handshakeTime.count());
1331   EXPECT_TRUE(!server.handshakeVerify_);
1332   EXPECT_TRUE(server.handshakeSuccess_);
1333   EXPECT_TRUE(!server.handshakeError_);
1334   EXPECT_LE(0, server.handshakeTime.count());
1335 }
1336
1337 /**
1338  * Verify that specifying that no validation should be performed allows an
1339  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1340  * callback.
1341  */
1342 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1343   EventBase eventBase;
1344   auto clientCtx = std::make_shared<SSLContext>();
1345   auto dfServerCtx = std::make_shared<SSLContext>();
1346
1347   int fds[2];
1348   getfds(fds);
1349   getctx(clientCtx, dfServerCtx);
1350
1351   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1352   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1353
1354   AsyncSSLSocket::UniquePtr clientSock(
1355     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1356   AsyncSSLSocket::UniquePtr serverSock(
1357     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1358
1359   SSLHandshakeClient client(std::move(clientSock), false, false);
1360   SSLHandshakeServer server(std::move(serverSock), false, false);
1361
1362   eventBase.loop();
1363
1364   EXPECT_TRUE(!client.handshakeVerify_);
1365   EXPECT_TRUE(client.handshakeSuccess_);
1366   EXPECT_TRUE(!client.handshakeError_);
1367   EXPECT_LE(0, client.handshakeTime.count());
1368   EXPECT_TRUE(!server.handshakeVerify_);
1369   EXPECT_TRUE(server.handshakeSuccess_);
1370   EXPECT_TRUE(!server.handshakeError_);
1371   EXPECT_LE(0, server.handshakeTime.count());
1372 }
1373
1374 /**
1375  * Test requireClientCert with client cert
1376  */
1377 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1378   EventBase eventBase;
1379   auto clientCtx = std::make_shared<SSLContext>();
1380   auto serverCtx = std::make_shared<SSLContext>();
1381   serverCtx->setVerificationOption(
1382       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1383   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1384   serverCtx->loadPrivateKey(testKey);
1385   serverCtx->loadCertificate(testCert);
1386   serverCtx->loadTrustedCertificates(testCA);
1387   serverCtx->loadClientCAList(testCA);
1388
1389   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1390   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1391   clientCtx->loadPrivateKey(testKey);
1392   clientCtx->loadCertificate(testCert);
1393   clientCtx->loadTrustedCertificates(testCA);
1394
1395   int fds[2];
1396   getfds(fds);
1397
1398   AsyncSSLSocket::UniquePtr clientSock(
1399       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1400   AsyncSSLSocket::UniquePtr serverSock(
1401       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1402
1403   SSLHandshakeClient client(std::move(clientSock), true, true);
1404   SSLHandshakeServer server(std::move(serverSock), true, true);
1405
1406   eventBase.loop();
1407
1408   EXPECT_TRUE(client.handshakeVerify_);
1409   EXPECT_TRUE(client.handshakeSuccess_);
1410   EXPECT_FALSE(client.handshakeError_);
1411   EXPECT_LE(0, client.handshakeTime.count());
1412   EXPECT_TRUE(server.handshakeVerify_);
1413   EXPECT_TRUE(server.handshakeSuccess_);
1414   EXPECT_FALSE(server.handshakeError_);
1415   EXPECT_LE(0, server.handshakeTime.count());
1416 }
1417
1418
1419 /**
1420  * Test requireClientCert with no client cert
1421  */
1422 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1423   EventBase eventBase;
1424   auto clientCtx = std::make_shared<SSLContext>();
1425   auto serverCtx = std::make_shared<SSLContext>();
1426   serverCtx->setVerificationOption(
1427       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1428   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1429   serverCtx->loadPrivateKey(testKey);
1430   serverCtx->loadCertificate(testCert);
1431   serverCtx->loadTrustedCertificates(testCA);
1432   serverCtx->loadClientCAList(testCA);
1433   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1434   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1435
1436   int fds[2];
1437   getfds(fds);
1438
1439   AsyncSSLSocket::UniquePtr clientSock(
1440       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1441   AsyncSSLSocket::UniquePtr serverSock(
1442       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1443
1444   SSLHandshakeClient client(std::move(clientSock), false, false);
1445   SSLHandshakeServer server(std::move(serverSock), false, false);
1446
1447   eventBase.loop();
1448
1449   EXPECT_FALSE(server.handshakeVerify_);
1450   EXPECT_FALSE(server.handshakeSuccess_);
1451   EXPECT_TRUE(server.handshakeError_);
1452   EXPECT_LE(0, client.handshakeTime.count());
1453   EXPECT_LE(0, server.handshakeTime.count());
1454 }
1455
1456 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1457   auto cert = getFileAsBuf(testCert);
1458   auto key = getFileAsBuf(testKey);
1459
1460   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1461   BIO_write(certBio.get(), cert.data(), cert.size());
1462   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1463   BIO_write(keyBio.get(), key.data(), key.size());
1464
1465   // Create SSL structs from buffers to get properties
1466   ssl::X509UniquePtr certStruct(
1467       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1468   ssl::EvpPkeyUniquePtr keyStruct(
1469       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1470   certBio = nullptr;
1471   keyBio = nullptr;
1472
1473   auto origCommonName = getCommonName(certStruct.get());
1474   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1475   certStruct = nullptr;
1476   keyStruct = nullptr;
1477
1478   auto ctx = std::make_shared<SSLContext>();
1479   ctx->loadPrivateKeyFromBufferPEM(key);
1480   ctx->loadCertificateFromBufferPEM(cert);
1481   ctx->loadTrustedCertificates(testCA);
1482
1483   ssl::SSLUniquePtr ssl(ctx->createSSL());
1484
1485   auto newCert = SSL_get_certificate(ssl.get());
1486   auto newKey = SSL_get_privatekey(ssl.get());
1487
1488   // Get properties from SSL struct
1489   auto newCommonName = getCommonName(newCert);
1490   auto newKeySize = EVP_PKEY_bits(newKey);
1491
1492   // Check that the key and cert have the expected properties
1493   EXPECT_EQ(origCommonName, newCommonName);
1494   EXPECT_EQ(origKeySize, newKeySize);
1495 }
1496
1497 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1498   EventBase eb;
1499
1500   // Set up SSL context.
1501   auto sslContext = std::make_shared<SSLContext>();
1502   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1503
1504   // create SSL socket
1505   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1506
1507   EXPECT_EQ(1500, socket->getMinWriteSize());
1508
1509   socket->setMinWriteSize(0);
1510   EXPECT_EQ(0, socket->getMinWriteSize());
1511   socket->setMinWriteSize(50000);
1512   EXPECT_EQ(50000, socket->getMinWriteSize());
1513 }
1514
1515 class ReadCallbackTerminator : public ReadCallback {
1516  public:
1517   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1518       : ReadCallback(wcb)
1519       , base_(base) {}
1520
1521   // Do not write data back, terminate the loop.
1522   void readDataAvailable(size_t len) noexcept override {
1523     std::cerr << "readDataAvailable, len " << len << std::endl;
1524
1525     currentBuffer.length = len;
1526
1527     buffers.push_back(currentBuffer);
1528     currentBuffer.reset();
1529     state = STATE_SUCCEEDED;
1530
1531     socket_->setReadCB(nullptr);
1532     base_->terminateLoopSoon();
1533   }
1534  private:
1535   EventBase* base_;
1536 };
1537
1538
1539 /**
1540  * Test a full unencrypted codepath
1541  */
1542 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1543   EventBase base;
1544
1545   auto clientCtx = std::make_shared<folly::SSLContext>();
1546   auto serverCtx = std::make_shared<folly::SSLContext>();
1547   int fds[2];
1548   getfds(fds);
1549   getctx(clientCtx, serverCtx);
1550   auto client = AsyncSSLSocket::newSocket(
1551                   clientCtx, &base, fds[0], false, true);
1552   auto server = AsyncSSLSocket::newSocket(
1553                   serverCtx, &base, fds[1], true, true);
1554
1555   ReadCallbackTerminator readCallback(&base, nullptr);
1556   server->setReadCB(&readCallback);
1557   readCallback.setSocket(server);
1558
1559   uint8_t buf[128];
1560   memset(buf, 'a', sizeof(buf));
1561   client->write(nullptr, buf, sizeof(buf));
1562
1563   // Check that bytes are unencrypted
1564   char c;
1565   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1566   EXPECT_EQ('a', c);
1567
1568   EventBaseAborter eba(&base, 3000);
1569   base.loop();
1570
1571   EXPECT_EQ(1, readCallback.buffers.size());
1572   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1573
1574   server->setReadCB(&readCallback);
1575
1576   // Unencrypted
1577   server->sslAccept(nullptr);
1578   client->sslConn(nullptr);
1579
1580   // Do NOT wait for handshake, writing should be queued and happen after
1581
1582   client->write(nullptr, buf, sizeof(buf));
1583
1584   // Check that bytes are *not* unencrypted
1585   char c2;
1586   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1587   EXPECT_NE('a', c2);
1588
1589
1590   base.loop();
1591
1592   EXPECT_EQ(2, readCallback.buffers.size());
1593   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1594 }
1595
1596 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1597   // Start listening on a local port
1598   WriteCallbackBase writeCallback;
1599   WriteErrorCallback readCallback(&writeCallback);
1600   HandshakeCallback handshakeCallback(&readCallback,
1601                                       HandshakeCallback::EXPECT_ERROR);
1602   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1603   TestSSLServer server(&acceptCallback);
1604
1605   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1606   socket->open();
1607   uint8_t buf[3] = {0x16, 0x03, 0x01};
1608   socket->write(buf, sizeof(buf));
1609   socket->closeWithReset();
1610
1611   handshakeCallback.waitForHandshake();
1612   EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
1613             std::string::npos);
1614   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1615 }
1616
1617 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1618   // Start listening on a local port
1619   WriteCallbackBase writeCallback;
1620   WriteErrorCallback readCallback(&writeCallback);
1621   HandshakeCallback handshakeCallback(&readCallback,
1622                                       HandshakeCallback::EXPECT_ERROR);
1623   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1624   TestSSLServer server(&acceptCallback);
1625
1626   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1627   socket->open();
1628   uint8_t buf[3] = {0x16, 0x03, 0x01};
1629   socket->write(buf, sizeof(buf));
1630   socket->close();
1631
1632   handshakeCallback.waitForHandshake();
1633   EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
1634             std::string::npos);
1635   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1636 }
1637
1638 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1639   // Start listening on a local port
1640   WriteCallbackBase writeCallback;
1641   WriteErrorCallback readCallback(&writeCallback);
1642   HandshakeCallback handshakeCallback(&readCallback,
1643                                       HandshakeCallback::EXPECT_ERROR);
1644   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1645   TestSSLServer server(&acceptCallback);
1646
1647   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1648   socket->open();
1649   uint8_t buf[256] = {0x16, 0x03};
1650   memset(buf + 2, 'a', sizeof(buf) - 2);
1651   socket->write(buf, sizeof(buf));
1652   socket->close();
1653
1654   handshakeCallback.waitForHandshake();
1655   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1656             std::string::npos);
1657   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1658             std::string::npos);
1659 }
1660
1661 } // namespace
1662
1663 ///////////////////////////////////////////////////////////////////////////
1664 // init_unit_test_suite
1665 ///////////////////////////////////////////////////////////////////////////
1666 namespace {
1667 struct Initializer {
1668   Initializer() {
1669     signal(SIGPIPE, SIG_IGN);
1670   }
1671 };
1672 Initializer initializer;
1673 } // anonymous