Add OpenSSL portability layer
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/GMock.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/Unistd.h>
28
29 #include <folly/io/async/test/BlockingSocket.h>
30
31 #include <fcntl.h>
32 #include <folly/io/Cursor.h>
33 #include <openssl/bio.h>
34 #include <sys/types.h>
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
56 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
57 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
58
59 constexpr size_t SSLClient::kMaxReadBufferSz;
60 constexpr size_t SSLClient::kMaxReadsPerEvent;
61
62 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
63     : ctx_(new folly::SSLContext),
64       acb_(acb),
65       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
66   // Set up the SSL context
67   ctx_->loadCertificate(testCert);
68   ctx_->loadPrivateKey(testKey);
69   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
70
71   acb_->ctx_ = ctx_;
72   acb_->base_ = &evb_;
73
74   // Enable TFO
75   if (enableTFO) {
76     LOG(INFO) << "server TFO enabled";
77     socket_->setTFOEnabled(true, 1000);
78   }
79
80   // set up the listening socket
81   socket_->bind(0);
82   socket_->getAddress(&address_);
83   socket_->listen(100);
84   socket_->addAcceptCallback(acb_, &evb_);
85   socket_->startAccepting();
86
87   int ret = pthread_create(&thread_, nullptr, Main, this);
88   assert(ret == 0);
89   (void)ret;
90
91   std::cerr << "Accepting connections on " << address_ << std::endl;
92 }
93
94 void getfds(int fds[2]) {
95   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
96     FAIL() << "failed to create socketpair: " << strerror(errno);
97   }
98   for (int idx = 0; idx < 2; ++idx) {
99     int flags = fcntl(fds[idx], F_GETFL, 0);
100     if (flags == -1) {
101       FAIL() << "failed to get flags for socket " << idx << ": "
102              << strerror(errno);
103     }
104     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
105       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
106              << strerror(errno);
107     }
108   }
109 }
110
111 void getctx(
112   std::shared_ptr<folly::SSLContext> clientCtx,
113   std::shared_ptr<folly::SSLContext> serverCtx) {
114   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
115
116   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
117   serverCtx->loadCertificate(
118       testCert);
119   serverCtx->loadPrivateKey(
120       testKey);
121 }
122
123 void sslsocketpair(
124   EventBase* eventBase,
125   AsyncSSLSocket::UniquePtr* clientSock,
126   AsyncSSLSocket::UniquePtr* serverSock) {
127   auto clientCtx = std::make_shared<folly::SSLContext>();
128   auto serverCtx = std::make_shared<folly::SSLContext>();
129   int fds[2];
130   getfds(fds);
131   getctx(clientCtx, serverCtx);
132   clientSock->reset(new AsyncSSLSocket(
133                       clientCtx, eventBase, fds[0], false));
134   serverSock->reset(new AsyncSSLSocket(
135                       serverCtx, eventBase, fds[1], true));
136
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
142 bool clientProtoFilterPickPony(unsigned char** client,
143   unsigned int* client_len, const unsigned char*, unsigned int ) {
144   //the protocol string in length prefixed byte string. the
145   //length byte is not included in the length
146   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
147   *client = p;
148   *client_len = 7;
149   return true;
150 }
151
152 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
153   const unsigned char*, unsigned int) {
154   return false;
155 }
156
157 std::string getFileAsBuf(const char* fileName) {
158   std::string buffer;
159   folly::readFile(fileName, buffer);
160   return buffer;
161 }
162
163 std::string getCommonName(X509* cert) {
164   X509_NAME* subject = X509_get_subject_name(cert);
165   std::string cn;
166   cn.resize(ub_common_name);
167   X509_NAME_get_text_by_NID(
168       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
169   return cn;
170 }
171
172 /**
173  * Test connecting to, writing to, reading from, and closing the
174  * connection to the SSL server.
175  */
176 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
177   // Start listening on a local port
178   WriteCallbackBase writeCallback;
179   ReadCallback readCallback(&writeCallback);
180   HandshakeCallback handshakeCallback(&readCallback);
181   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
182   TestSSLServer server(&acceptCallback);
183
184   // Set up SSL context.
185   std::shared_ptr<SSLContext> sslContext(new SSLContext());
186   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
187   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
188   //sslContext->authenticate(true, false);
189
190   // connect
191   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
192                                                  sslContext);
193   socket->open();
194
195   // write()
196   uint8_t buf[128];
197   memset(buf, 'a', sizeof(buf));
198   socket->write(buf, sizeof(buf));
199
200   // read()
201   uint8_t readbuf[128];
202   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
203   EXPECT_EQ(bytesRead, 128);
204   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
205
206   // close()
207   socket->close();
208
209   cerr << "ConnectWriteReadClose test completed" << endl;
210 }
211
212 /**
213  * Test reading after server close.
214  */
215 TEST(AsyncSSLSocketTest, ReadAfterClose) {
216   // Start listening on a local port
217   WriteCallbackBase writeCallback;
218   ReadEOFCallback readCallback(&writeCallback);
219   HandshakeCallback handshakeCallback(&readCallback);
220   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
221   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
222
223   // Set up SSL context.
224   auto sslContext = std::make_shared<SSLContext>();
225   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
226
227   auto socket =
228       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
229   socket->open();
230
231   // This should trigger an EOF on the client.
232   auto evb = handshakeCallback.getSocket()->getEventBase();
233   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
234   std::array<uint8_t, 128> readbuf;
235   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
236   EXPECT_EQ(0, bytesRead);
237 }
238
239 /**
240  * Test bad renegotiation
241  */
242 #if !defined(OPENSSL_IS_BORINGSSL)
243 TEST(AsyncSSLSocketTest, Renegotiate) {
244   EventBase eventBase;
245   auto clientCtx = std::make_shared<SSLContext>();
246   auto dfServerCtx = std::make_shared<SSLContext>();
247   std::array<int, 2> fds;
248   getfds(fds.data());
249   getctx(clientCtx, dfServerCtx);
250
251   AsyncSSLSocket::UniquePtr clientSock(
252       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
253   AsyncSSLSocket::UniquePtr serverSock(
254       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
255   SSLHandshakeClient client(std::move(clientSock), true, true);
256   RenegotiatingServer server(std::move(serverSock));
257
258   while (!client.handshakeSuccess_ && !client.handshakeError_) {
259     eventBase.loopOnce();
260   }
261
262   ASSERT_TRUE(client.handshakeSuccess_);
263
264   auto sslSock = std::move(client).moveSocket();
265   sslSock->detachEventBase();
266   // This is nasty, however we don't want to add support for
267   // renegotiation in AsyncSSLSocket.
268   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
269
270   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
271
272   std::thread t([&]() { eventBase.loopForever(); });
273
274   // Trigger the renegotiation.
275   std::array<uint8_t, 128> buf;
276   memset(buf.data(), 'a', buf.size());
277   try {
278     socket->write(buf.data(), buf.size());
279   } catch (AsyncSocketException& e) {
280     LOG(INFO) << "client got error " << e.what();
281   }
282   eventBase.terminateLoopSoon();
283   t.join();
284
285   eventBase.loop();
286   ASSERT_TRUE(server.renegotiationError_);
287 }
288 #endif
289
290 /**
291  * Negative test for handshakeError().
292  */
293 TEST(AsyncSSLSocketTest, HandshakeError) {
294   // Start listening on a local port
295   WriteCallbackBase writeCallback;
296   WriteErrorCallback readCallback(&writeCallback);
297   HandshakeCallback handshakeCallback(&readCallback);
298   HandshakeErrorCallback acceptCallback(&handshakeCallback);
299   TestSSLServer server(&acceptCallback);
300
301   // Set up SSL context.
302   std::shared_ptr<SSLContext> sslContext(new SSLContext());
303   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
304
305   // connect
306   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
307                                                  sslContext);
308   // read()
309   bool ex = false;
310   try {
311     socket->open();
312
313     uint8_t readbuf[128];
314     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
315     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
316   } catch (AsyncSocketException&) {
317     ex = true;
318   }
319   EXPECT_TRUE(ex);
320
321   // close()
322   socket->close();
323   cerr << "HandshakeError test completed" << endl;
324 }
325
326 /**
327  * Negative test for readError().
328  */
329 TEST(AsyncSSLSocketTest, ReadError) {
330   // Start listening on a local port
331   WriteCallbackBase writeCallback;
332   ReadErrorCallback readCallback(&writeCallback);
333   HandshakeCallback handshakeCallback(&readCallback);
334   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
335   TestSSLServer server(&acceptCallback);
336
337   // Set up SSL context.
338   std::shared_ptr<SSLContext> sslContext(new SSLContext());
339   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
340
341   // connect
342   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
343                                                  sslContext);
344   socket->open();
345
346   // write something to trigger ssl handshake
347   uint8_t buf[128];
348   memset(buf, 'a', sizeof(buf));
349   socket->write(buf, sizeof(buf));
350
351   socket->close();
352   cerr << "ReadError test completed" << endl;
353 }
354
355 /**
356  * Negative test for writeError().
357  */
358 TEST(AsyncSSLSocketTest, WriteError) {
359   // Start listening on a local port
360   WriteCallbackBase writeCallback;
361   WriteErrorCallback readCallback(&writeCallback);
362   HandshakeCallback handshakeCallback(&readCallback);
363   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
364   TestSSLServer server(&acceptCallback);
365
366   // Set up SSL context.
367   std::shared_ptr<SSLContext> sslContext(new SSLContext());
368   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
369
370   // connect
371   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
372                                                  sslContext);
373   socket->open();
374
375   // write something to trigger ssl handshake
376   uint8_t buf[128];
377   memset(buf, 'a', sizeof(buf));
378   socket->write(buf, sizeof(buf));
379
380   socket->close();
381   cerr << "WriteError test completed" << endl;
382 }
383
384 /**
385  * Test a socket with TCP_NODELAY unset.
386  */
387 TEST(AsyncSSLSocketTest, SocketWithDelay) {
388   // Start listening on a local port
389   WriteCallbackBase writeCallback;
390   ReadCallback readCallback(&writeCallback);
391   HandshakeCallback handshakeCallback(&readCallback);
392   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
393   TestSSLServer server(&acceptCallback);
394
395   // Set up SSL context.
396   std::shared_ptr<SSLContext> sslContext(new SSLContext());
397   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
398
399   // connect
400   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
401                                                  sslContext);
402   socket->open();
403
404   // write()
405   uint8_t buf[128];
406   memset(buf, 'a', sizeof(buf));
407   socket->write(buf, sizeof(buf));
408
409   // read()
410   uint8_t readbuf[128];
411   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
412   EXPECT_EQ(bytesRead, 128);
413   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
414
415   // close()
416   socket->close();
417
418   cerr << "SocketWithDelay test completed" << endl;
419 }
420
421 using NextProtocolTypePair =
422     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
423
424 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
425   // For matching protos
426  public:
427   void SetUp() override { getctx(clientCtx, serverCtx); }
428
429   void connect(bool unset = false) {
430     getfds(fds);
431
432     if (unset) {
433       // unsetting NPN for any of [client, server] is enough to make NPN not
434       // work
435       clientCtx->unsetNextProtocols();
436     }
437
438     AsyncSSLSocket::UniquePtr clientSock(
439       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
440     AsyncSSLSocket::UniquePtr serverSock(
441       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
442     client = folly::make_unique<NpnClient>(std::move(clientSock));
443     server = folly::make_unique<NpnServer>(std::move(serverSock));
444
445     eventBase.loop();
446   }
447
448   void expectProtocol(const std::string& proto) {
449     EXPECT_NE(client->nextProtoLength, 0);
450     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
451     EXPECT_EQ(
452         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
453         0);
454     string selected((const char*)client->nextProto, client->nextProtoLength);
455     EXPECT_EQ(proto, selected);
456   }
457
458   void expectNoProtocol() {
459     EXPECT_EQ(client->nextProtoLength, 0);
460     EXPECT_EQ(server->nextProtoLength, 0);
461     EXPECT_EQ(client->nextProto, nullptr);
462     EXPECT_EQ(server->nextProto, nullptr);
463   }
464
465   void expectProtocolType() {
466     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
467         GetParam().second == SSLContext::NextProtocolType::ANY) {
468       EXPECT_EQ(client->protocolType, server->protocolType);
469     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
470                GetParam().second == SSLContext::NextProtocolType::ANY) {
471       // Well not much we can say
472     } else {
473       expectProtocolType(GetParam());
474     }
475   }
476
477   void expectProtocolType(NextProtocolTypePair expected) {
478     EXPECT_EQ(client->protocolType, expected.first);
479     EXPECT_EQ(server->protocolType, expected.second);
480   }
481
482   EventBase eventBase;
483   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
484   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
485   int fds[2];
486   std::unique_ptr<NpnClient> client;
487   std::unique_ptr<NpnServer> server;
488 };
489
490 class NextProtocolTLSExtTest : public NextProtocolTest {
491   // For extended TLS protos
492 };
493
494 class NextProtocolNPNOnlyTest : public NextProtocolTest {
495   // For mismatching protos
496 };
497
498 class NextProtocolMismatchTest : public NextProtocolTest {
499   // For mismatching protos
500 };
501
502 TEST_P(NextProtocolTest, NpnTestOverlap) {
503   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
504   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
505                                         GetParam().second);
506
507   connect();
508
509   expectProtocol("baz");
510   expectProtocolType();
511 }
512
513 TEST_P(NextProtocolTest, NpnTestUnset) {
514   // Identical to above test, except that we want unset NPN before
515   // looping.
516   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
517   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
518                                         GetParam().second);
519
520   connect(true /* unset */);
521
522   // if alpn negotiation fails, type will appear as npn
523   expectNoProtocol();
524   EXPECT_EQ(client->protocolType, server->protocolType);
525 }
526
527 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
528   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
529   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
530                                         GetParam().second);
531
532   connect();
533
534   expectNoProtocol();
535   expectProtocolType(
536       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
537 }
538
539 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
540 // will fail on 1.0.2 before that.
541 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
542   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
543   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
544                                         GetParam().second);
545
546   connect();
547
548   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
549       GetParam().second == SSLContext::NextProtocolType::ALPN) {
550     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
551     // mismatch should result in a fatal alert, but this is OpenSSL's current
552     // behavior and we want to know if it changes.
553     expectNoProtocol();
554   }
555 #if defined(OPENSSL_IS_BORINGSSL)
556   // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
557   // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
558   // NPN *and* ALPN
559   else if (
560       GetParam().first == SSLContext::NextProtocolType::ANY &&
561       GetParam().second == SSLContext::NextProtocolType::ANY) {
562     expectNoProtocol();
563   }
564 #endif
565   else {
566     expectProtocol("blub");
567     expectProtocolType(
568         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
569   }
570 }
571
572 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
573   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
574   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
575   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
576                                         GetParam().second);
577
578   connect();
579
580   expectProtocol("ponies");
581   expectProtocolType();
582 }
583
584 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
585   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
586   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
587   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
588                                         GetParam().second);
589
590   connect();
591
592   expectProtocol("blub");
593   expectProtocolType();
594 }
595
596 TEST_P(NextProtocolTest, RandomizedNpnTest) {
597   // Probability that this test will fail is 2^-64, which could be considered
598   // as negligible.
599   const int kTries = 64;
600
601   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
602                                         GetParam().first);
603   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
604                                                   GetParam().second);
605
606   std::set<string> selectedProtocols;
607   for (int i = 0; i < kTries; ++i) {
608     connect();
609
610     EXPECT_NE(client->nextProtoLength, 0);
611     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
612     EXPECT_EQ(
613         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
614         0);
615     string selected((const char*)client->nextProto, client->nextProtoLength);
616     selectedProtocols.insert(selected);
617     expectProtocolType();
618   }
619   EXPECT_EQ(selectedProtocols.size(), 2);
620 }
621
622 INSTANTIATE_TEST_CASE_P(
623     AsyncSSLSocketTest,
624     NextProtocolTest,
625     ::testing::Values(
626         NextProtocolTypePair(
627             SSLContext::NextProtocolType::NPN,
628             SSLContext::NextProtocolType::NPN),
629         NextProtocolTypePair(
630             SSLContext::NextProtocolType::NPN,
631             SSLContext::NextProtocolType::ANY),
632         NextProtocolTypePair(
633             SSLContext::NextProtocolType::ANY,
634             SSLContext::NextProtocolType::ANY)));
635
636 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
637 INSTANTIATE_TEST_CASE_P(
638     AsyncSSLSocketTest,
639     NextProtocolTLSExtTest,
640     ::testing::Values(
641         NextProtocolTypePair(
642             SSLContext::NextProtocolType::ALPN,
643             SSLContext::NextProtocolType::ALPN),
644         NextProtocolTypePair(
645             SSLContext::NextProtocolType::ALPN,
646             SSLContext::NextProtocolType::ANY),
647         NextProtocolTypePair(
648             SSLContext::NextProtocolType::ANY,
649             SSLContext::NextProtocolType::ALPN)));
650 #endif
651
652 INSTANTIATE_TEST_CASE_P(
653     AsyncSSLSocketTest,
654     NextProtocolNPNOnlyTest,
655     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
656                                            SSLContext::NextProtocolType::NPN)));
657
658 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
659 INSTANTIATE_TEST_CASE_P(
660     AsyncSSLSocketTest,
661     NextProtocolMismatchTest,
662     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
663                                            SSLContext::NextProtocolType::ALPN),
664                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
665                                            SSLContext::NextProtocolType::NPN)));
666 #endif
667
668 #ifndef OPENSSL_NO_TLSEXT
669 /**
670  * 1. Client sends TLSEXT_HOSTNAME in client hello.
671  * 2. Server found a match SSL_CTX and use this SSL_CTX to
672  *    continue the SSL handshake.
673  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
674  */
675 TEST(AsyncSSLSocketTest, SNITestMatch) {
676   EventBase eventBase;
677   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
678   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
679   // Use the same SSLContext to continue the handshake after
680   // tlsext_hostname match.
681   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
682   const std::string serverName("xyz.newdev.facebook.com");
683   int fds[2];
684   getfds(fds);
685   getctx(clientCtx, dfServerCtx);
686
687   AsyncSSLSocket::UniquePtr clientSock(
688     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
689   AsyncSSLSocket::UniquePtr serverSock(
690     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
691   SNIClient client(std::move(clientSock));
692   SNIServer server(std::move(serverSock),
693                    dfServerCtx,
694                    hskServerCtx,
695                    serverName);
696
697   eventBase.loop();
698
699   EXPECT_TRUE(client.serverNameMatch);
700   EXPECT_TRUE(server.serverNameMatch);
701 }
702
703 /**
704  * 1. Client sends TLSEXT_HOSTNAME in client hello.
705  * 2. Server cannot find a matching SSL_CTX and continue to use
706  *    the current SSL_CTX to do the handshake.
707  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
708  */
709 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
710   EventBase eventBase;
711   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
712   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
713   // Use the same SSLContext to continue the handshake after
714   // tlsext_hostname match.
715   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
716   const std::string clientRequestingServerName("foo.com");
717   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
718
719   int fds[2];
720   getfds(fds);
721   getctx(clientCtx, dfServerCtx);
722
723   AsyncSSLSocket::UniquePtr clientSock(
724     new AsyncSSLSocket(clientCtx,
725                         &eventBase,
726                         fds[0],
727                         clientRequestingServerName));
728   AsyncSSLSocket::UniquePtr serverSock(
729     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
730   SNIClient client(std::move(clientSock));
731   SNIServer server(std::move(serverSock),
732                    dfServerCtx,
733                    hskServerCtx,
734                    serverExpectedServerName);
735
736   eventBase.loop();
737
738   EXPECT_TRUE(!client.serverNameMatch);
739   EXPECT_TRUE(!server.serverNameMatch);
740 }
741 /**
742  * 1. Client sends TLSEXT_HOSTNAME in client hello.
743  * 2. We then change the serverName.
744  * 3. We expect that we get 'false' as the result for serNameMatch.
745  */
746
747 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
748    EventBase eventBase;
749   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
750   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
751   // Use the same SSLContext to continue the handshake after
752   // tlsext_hostname match.
753   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
754   const std::string serverName("xyz.newdev.facebook.com");
755   int fds[2];
756   getfds(fds);
757   getctx(clientCtx, dfServerCtx);
758
759   AsyncSSLSocket::UniquePtr clientSock(
760     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
761   //Change the server name
762   std::string newName("new.com");
763   clientSock->setServerName(newName);
764   AsyncSSLSocket::UniquePtr serverSock(
765     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
766   SNIClient client(std::move(clientSock));
767   SNIServer server(std::move(serverSock),
768                    dfServerCtx,
769                    hskServerCtx,
770                    serverName);
771
772   eventBase.loop();
773
774   EXPECT_TRUE(!client.serverNameMatch);
775 }
776
777 /**
778  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
779  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
780  */
781 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
782   EventBase eventBase;
783   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
784   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
785   // Use the same SSLContext to continue the handshake after
786   // tlsext_hostname match.
787   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
788   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
789
790   int fds[2];
791   getfds(fds);
792   getctx(clientCtx, dfServerCtx);
793
794   AsyncSSLSocket::UniquePtr clientSock(
795     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
796   AsyncSSLSocket::UniquePtr serverSock(
797     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
798   SNIClient client(std::move(clientSock));
799   SNIServer server(std::move(serverSock),
800                    dfServerCtx,
801                    hskServerCtx,
802                    serverExpectedServerName);
803
804   eventBase.loop();
805
806   EXPECT_TRUE(!client.serverNameMatch);
807   EXPECT_TRUE(!server.serverNameMatch);
808 }
809
810 #endif
811 /**
812  * Test SSL client socket
813  */
814 TEST(AsyncSSLSocketTest, SSLClientTest) {
815   // Start listening on a local port
816   WriteCallbackBase writeCallback;
817   ReadCallback readCallback(&writeCallback);
818   HandshakeCallback handshakeCallback(&readCallback);
819   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
820   TestSSLServer server(&acceptCallback);
821
822   // Set up SSL client
823   EventBase eventBase;
824   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
825
826   client->connect();
827   EventBaseAborter eba(&eventBase, 3000);
828   eventBase.loop();
829
830   EXPECT_EQ(client->getMiss(), 1);
831   EXPECT_EQ(client->getHit(), 0);
832
833   cerr << "SSLClientTest test completed" << endl;
834 }
835
836
837 /**
838  * Test SSL client socket session re-use
839  */
840 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
841   // Start listening on a local port
842   WriteCallbackBase writeCallback;
843   ReadCallback readCallback(&writeCallback);
844   HandshakeCallback handshakeCallback(&readCallback);
845   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
846   TestSSLServer server(&acceptCallback);
847
848   // Set up SSL client
849   EventBase eventBase;
850   auto client =
851       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
852
853   client->connect();
854   EventBaseAborter eba(&eventBase, 3000);
855   eventBase.loop();
856
857   EXPECT_EQ(client->getMiss(), 1);
858   EXPECT_EQ(client->getHit(), 9);
859
860   cerr << "SSLClientTestReuse test completed" << endl;
861 }
862
863 /**
864  * Test SSL client socket timeout
865  */
866 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
867   // Start listening on a local port
868   EmptyReadCallback readCallback;
869   HandshakeCallback handshakeCallback(&readCallback,
870                                       HandshakeCallback::EXPECT_ERROR);
871   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
872   TestSSLServer server(&acceptCallback);
873
874   // Set up SSL client
875   EventBase eventBase;
876   auto client =
877       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
878   client->connect(true /* write before connect completes */);
879   EventBaseAborter eba(&eventBase, 3000);
880   eventBase.loop();
881
882   usleep(100000);
883   // This is checking that the connectError callback precedes any queued
884   // writeError callbacks.  This matches AsyncSocket's behavior
885   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
886   EXPECT_EQ(client->getErrors(), 1);
887   EXPECT_EQ(client->getMiss(), 0);
888   EXPECT_EQ(client->getHit(), 0);
889
890   cerr << "SSLClientTimeoutTest test completed" << endl;
891 }
892
893 // The next 3 tests need an FB-only extension, and will fail without it
894 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
895 /**
896  * Test SSL server async cache
897  */
898 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
899   // Start listening on a local port
900   WriteCallbackBase writeCallback;
901   ReadCallback readCallback(&writeCallback);
902   HandshakeCallback handshakeCallback(&readCallback);
903   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
904   TestSSLAsyncCacheServer server(&acceptCallback);
905
906   // Set up SSL client
907   EventBase eventBase;
908   auto client =
909       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
910
911   client->connect();
912   EventBaseAborter eba(&eventBase, 3000);
913   eventBase.loop();
914
915   EXPECT_EQ(server.getAsyncCallbacks(), 18);
916   EXPECT_EQ(server.getAsyncLookups(), 9);
917   EXPECT_EQ(client->getMiss(), 10);
918   EXPECT_EQ(client->getHit(), 0);
919
920   cerr << "SSLServerAsyncCacheTest test completed" << endl;
921 }
922
923 /**
924  * Test SSL server accept timeout with cache path
925  */
926 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
927   // Start listening on a local port
928   WriteCallbackBase writeCallback;
929   ReadCallback readCallback(&writeCallback);
930   HandshakeCallback handshakeCallback(&readCallback);
931   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
932   TestSSLAsyncCacheServer server(&acceptCallback);
933
934   // Set up SSL client
935   EventBase eventBase;
936   // only do a TCP connect
937   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
938   sock->connect(nullptr, server.getAddress());
939
940   EmptyReadCallback clientReadCallback;
941   clientReadCallback.tcpSocket_ = sock;
942   sock->setReadCB(&clientReadCallback);
943
944   EventBaseAborter eba(&eventBase, 3000);
945   eventBase.loop();
946
947   EXPECT_EQ(readCallback.state, STATE_WAITING);
948
949   cerr << "SSLServerTimeoutTest test completed" << endl;
950 }
951
952 /**
953  * Test SSL server accept timeout with cache path
954  */
955 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
956   // Start listening on a local port
957   WriteCallbackBase writeCallback;
958   ReadCallback readCallback(&writeCallback);
959   HandshakeCallback handshakeCallback(&readCallback);
960   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
961   TestSSLAsyncCacheServer server(&acceptCallback);
962
963   // Set up SSL client
964   EventBase eventBase;
965   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
966
967   client->connect();
968   EventBaseAborter eba(&eventBase, 3000);
969   eventBase.loop();
970
971   EXPECT_EQ(server.getAsyncCallbacks(), 1);
972   EXPECT_EQ(server.getAsyncLookups(), 1);
973   EXPECT_EQ(client->getErrors(), 1);
974   EXPECT_EQ(client->getMiss(), 1);
975   EXPECT_EQ(client->getHit(), 0);
976
977   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
978 }
979
980 /**
981  * Test SSL server accept timeout with cache path
982  */
983 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
984   // Start listening on a local port
985   WriteCallbackBase writeCallback;
986   ReadCallback readCallback(&writeCallback);
987   HandshakeCallback handshakeCallback(&readCallback,
988                                       HandshakeCallback::EXPECT_ERROR);
989   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
990   TestSSLAsyncCacheServer server(&acceptCallback, 500);
991
992   // Set up SSL client
993   EventBase eventBase;
994   auto client =
995       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
996
997   client->connect();
998   EventBaseAborter eba(&eventBase, 3000);
999   eventBase.loop();
1000
1001   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
1002       handshakeCallback.closeSocket();});
1003   // give time for the cache lookup to come back and find it closed
1004   handshakeCallback.waitForHandshake();
1005
1006   EXPECT_EQ(server.getAsyncCallbacks(), 1);
1007   EXPECT_EQ(server.getAsyncLookups(), 1);
1008   EXPECT_EQ(client->getErrors(), 1);
1009   EXPECT_EQ(client->getMiss(), 1);
1010   EXPECT_EQ(client->getHit(), 0);
1011
1012   cerr << "SSLServerCacheCloseTest test completed" << endl;
1013 }
1014 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
1015
1016 /**
1017  * Verify Client Ciphers obtained using SSL MSG Callback.
1018  */
1019 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1020   EventBase eventBase;
1021   auto clientCtx = std::make_shared<SSLContext>();
1022   auto serverCtx = std::make_shared<SSLContext>();
1023   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1024   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
1025   serverCtx->loadPrivateKey(testKey);
1026   serverCtx->loadCertificate(testCert);
1027   serverCtx->loadTrustedCertificates(testCA);
1028   serverCtx->loadClientCAList(testCA);
1029
1030   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1031   clientCtx->ciphers("AES256-SHA:AES128-SHA");
1032   clientCtx->loadPrivateKey(testKey);
1033   clientCtx->loadCertificate(testCert);
1034   clientCtx->loadTrustedCertificates(testCA);
1035
1036   int fds[2];
1037   getfds(fds);
1038
1039   AsyncSSLSocket::UniquePtr clientSock(
1040       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1041   AsyncSSLSocket::UniquePtr serverSock(
1042       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1043
1044   SSLHandshakeClient client(std::move(clientSock), true, true);
1045   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1046
1047   eventBase.loop();
1048
1049 #if defined(OPENSSL_IS_BORINGSSL)
1050   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1051 #else
1052   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1053 #endif
1054   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1055   EXPECT_TRUE(client.handshakeVerify_);
1056   EXPECT_TRUE(client.handshakeSuccess_);
1057   EXPECT_TRUE(!client.handshakeError_);
1058   EXPECT_TRUE(server.handshakeVerify_);
1059   EXPECT_TRUE(server.handshakeSuccess_);
1060   EXPECT_TRUE(!server.handshakeError_);
1061 }
1062
1063 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1064   EventBase eventBase;
1065   auto ctx = std::make_shared<SSLContext>();
1066
1067   int fds[2];
1068   getfds(fds);
1069
1070   int bufLen = 42;
1071   uint8_t majorVersion = 18;
1072   uint8_t minorVersion = 25;
1073
1074   // Create callback buf
1075   auto buf = IOBuf::create(bufLen);
1076   buf->append(bufLen);
1077   folly::io::RWPrivateCursor cursor(buf.get());
1078   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1079   cursor.write<uint16_t>(0);
1080   cursor.write<uint8_t>(38);
1081   cursor.write<uint8_t>(majorVersion);
1082   cursor.write<uint8_t>(minorVersion);
1083   cursor.skip(32);
1084   cursor.write<uint32_t>(0);
1085
1086   SSL* ssl = ctx->createSSL();
1087   SCOPE_EXIT { SSL_free(ssl); };
1088   AsyncSSLSocket::UniquePtr sock(
1089       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1090   sock->enableClientHelloParsing();
1091
1092   // Test client hello parsing in one packet
1093   AsyncSSLSocket::clientHelloParsingCallback(
1094       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1095   buf.reset();
1096
1097   auto parsedClientHello = sock->getClientHelloInfo();
1098   EXPECT_TRUE(parsedClientHello != nullptr);
1099   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1100   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1101 }
1102
1103 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1104   EventBase eventBase;
1105   auto ctx = std::make_shared<SSLContext>();
1106
1107   int fds[2];
1108   getfds(fds);
1109
1110   int bufLen = 42;
1111   uint8_t majorVersion = 18;
1112   uint8_t minorVersion = 25;
1113
1114   // Create callback buf
1115   auto buf = IOBuf::create(bufLen);
1116   buf->append(bufLen);
1117   folly::io::RWPrivateCursor cursor(buf.get());
1118   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1119   cursor.write<uint16_t>(0);
1120   cursor.write<uint8_t>(38);
1121   cursor.write<uint8_t>(majorVersion);
1122   cursor.write<uint8_t>(minorVersion);
1123   cursor.skip(32);
1124   cursor.write<uint32_t>(0);
1125
1126   SSL* ssl = ctx->createSSL();
1127   SCOPE_EXIT { SSL_free(ssl); };
1128   AsyncSSLSocket::UniquePtr sock(
1129       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1130   sock->enableClientHelloParsing();
1131
1132   // Test parsing with two packets with first packet size < 3
1133   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1134   AsyncSSLSocket::clientHelloParsingCallback(
1135       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1136       ssl, sock.get());
1137   bufCopy.reset();
1138   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1139   AsyncSSLSocket::clientHelloParsingCallback(
1140       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1141       ssl, sock.get());
1142   bufCopy.reset();
1143
1144   auto parsedClientHello = sock->getClientHelloInfo();
1145   EXPECT_TRUE(parsedClientHello != nullptr);
1146   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1147   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1148 }
1149
1150 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1151   EventBase eventBase;
1152   auto ctx = std::make_shared<SSLContext>();
1153
1154   int fds[2];
1155   getfds(fds);
1156
1157   int bufLen = 42;
1158   uint8_t majorVersion = 18;
1159   uint8_t minorVersion = 25;
1160
1161   // Create callback buf
1162   auto buf = IOBuf::create(bufLen);
1163   buf->append(bufLen);
1164   folly::io::RWPrivateCursor cursor(buf.get());
1165   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1166   cursor.write<uint16_t>(0);
1167   cursor.write<uint8_t>(38);
1168   cursor.write<uint8_t>(majorVersion);
1169   cursor.write<uint8_t>(minorVersion);
1170   cursor.skip(32);
1171   cursor.write<uint32_t>(0);
1172
1173   SSL* ssl = ctx->createSSL();
1174   SCOPE_EXIT { SSL_free(ssl); };
1175   AsyncSSLSocket::UniquePtr sock(
1176       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1177   sock->enableClientHelloParsing();
1178
1179   // Test parsing with multiple small packets
1180   for (uint64_t i = 0; i < buf->length(); i += 3) {
1181     auto bufCopy = folly::IOBuf::copyBuffer(
1182         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1183     AsyncSSLSocket::clientHelloParsingCallback(
1184         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1185         ssl, sock.get());
1186     bufCopy.reset();
1187   }
1188
1189   auto parsedClientHello = sock->getClientHelloInfo();
1190   EXPECT_TRUE(parsedClientHello != nullptr);
1191   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1192   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1193 }
1194
1195 /**
1196  * Verify sucessful behavior of SSL certificate validation.
1197  */
1198 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1199   EventBase eventBase;
1200   auto clientCtx = std::make_shared<SSLContext>();
1201   auto dfServerCtx = std::make_shared<SSLContext>();
1202
1203   int fds[2];
1204   getfds(fds);
1205   getctx(clientCtx, dfServerCtx);
1206
1207   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1208   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1209
1210   AsyncSSLSocket::UniquePtr clientSock(
1211     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1212   AsyncSSLSocket::UniquePtr serverSock(
1213     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1214
1215   SSLHandshakeClient client(std::move(clientSock), true, true);
1216   clientCtx->loadTrustedCertificates(testCA);
1217
1218   SSLHandshakeServer server(std::move(serverSock), true, true);
1219
1220   eventBase.loop();
1221
1222   EXPECT_TRUE(client.handshakeVerify_);
1223   EXPECT_TRUE(client.handshakeSuccess_);
1224   EXPECT_TRUE(!client.handshakeError_);
1225   EXPECT_LE(0, client.handshakeTime.count());
1226   EXPECT_TRUE(!server.handshakeVerify_);
1227   EXPECT_TRUE(server.handshakeSuccess_);
1228   EXPECT_TRUE(!server.handshakeError_);
1229   EXPECT_LE(0, server.handshakeTime.count());
1230 }
1231
1232 /**
1233  * Verify that the client's verification callback is able to fail SSL
1234  * connection establishment.
1235  */
1236 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1237   EventBase eventBase;
1238   auto clientCtx = std::make_shared<SSLContext>();
1239   auto dfServerCtx = std::make_shared<SSLContext>();
1240
1241   int fds[2];
1242   getfds(fds);
1243   getctx(clientCtx, dfServerCtx);
1244
1245   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1246   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1247
1248   AsyncSSLSocket::UniquePtr clientSock(
1249     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1250   AsyncSSLSocket::UniquePtr serverSock(
1251     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1252
1253   SSLHandshakeClient client(std::move(clientSock), true, false);
1254   clientCtx->loadTrustedCertificates(testCA);
1255
1256   SSLHandshakeServer server(std::move(serverSock), true, true);
1257
1258   eventBase.loop();
1259
1260   EXPECT_TRUE(client.handshakeVerify_);
1261   EXPECT_TRUE(!client.handshakeSuccess_);
1262   EXPECT_TRUE(client.handshakeError_);
1263   EXPECT_LE(0, client.handshakeTime.count());
1264   EXPECT_TRUE(!server.handshakeVerify_);
1265   EXPECT_TRUE(!server.handshakeSuccess_);
1266   EXPECT_TRUE(server.handshakeError_);
1267   EXPECT_LE(0, server.handshakeTime.count());
1268 }
1269
1270 /**
1271  * Verify that the options in SSLContext can be overridden in
1272  * sslConnect/Accept.i.e specifying that no validation should be performed
1273  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1274  * the validation callback.
1275  */
1276 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1277   EventBase eventBase;
1278   auto clientCtx = std::make_shared<SSLContext>();
1279   auto dfServerCtx = std::make_shared<SSLContext>();
1280
1281   int fds[2];
1282   getfds(fds);
1283   getctx(clientCtx, dfServerCtx);
1284
1285   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1286   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1287
1288   AsyncSSLSocket::UniquePtr clientSock(
1289     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1290   AsyncSSLSocket::UniquePtr serverSock(
1291     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1292
1293   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1294   clientCtx->loadTrustedCertificates(testCA);
1295
1296   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1297
1298   eventBase.loop();
1299
1300   EXPECT_TRUE(!client.handshakeVerify_);
1301   EXPECT_TRUE(client.handshakeSuccess_);
1302   EXPECT_TRUE(!client.handshakeError_);
1303   EXPECT_LE(0, client.handshakeTime.count());
1304   EXPECT_TRUE(!server.handshakeVerify_);
1305   EXPECT_TRUE(server.handshakeSuccess_);
1306   EXPECT_TRUE(!server.handshakeError_);
1307   EXPECT_LE(0, server.handshakeTime.count());
1308 }
1309
1310 /**
1311  * Verify that the options in SSLContext can be overridden in
1312  * sslConnect/Accept. Enable verification even if context says otherwise.
1313  * Test requireClientCert with client cert
1314  */
1315 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1316   EventBase eventBase;
1317   auto clientCtx = std::make_shared<SSLContext>();
1318   auto serverCtx = std::make_shared<SSLContext>();
1319   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1320   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1321   serverCtx->loadPrivateKey(testKey);
1322   serverCtx->loadCertificate(testCert);
1323   serverCtx->loadTrustedCertificates(testCA);
1324   serverCtx->loadClientCAList(testCA);
1325
1326   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1327   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1328   clientCtx->loadPrivateKey(testKey);
1329   clientCtx->loadCertificate(testCert);
1330   clientCtx->loadTrustedCertificates(testCA);
1331
1332   int fds[2];
1333   getfds(fds);
1334
1335   AsyncSSLSocket::UniquePtr clientSock(
1336       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1337   AsyncSSLSocket::UniquePtr serverSock(
1338       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1339
1340   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1341   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1342
1343   eventBase.loop();
1344
1345   EXPECT_TRUE(client.handshakeVerify_);
1346   EXPECT_TRUE(client.handshakeSuccess_);
1347   EXPECT_FALSE(client.handshakeError_);
1348   EXPECT_LE(0, client.handshakeTime.count());
1349   EXPECT_TRUE(server.handshakeVerify_);
1350   EXPECT_TRUE(server.handshakeSuccess_);
1351   EXPECT_FALSE(server.handshakeError_);
1352   EXPECT_LE(0, server.handshakeTime.count());
1353 }
1354
1355 /**
1356  * Verify that the client's verification callback is able to override
1357  * the preverification failure and allow a successful connection.
1358  */
1359 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1360   EventBase eventBase;
1361   auto clientCtx = std::make_shared<SSLContext>();
1362   auto dfServerCtx = std::make_shared<SSLContext>();
1363
1364   int fds[2];
1365   getfds(fds);
1366   getctx(clientCtx, dfServerCtx);
1367
1368   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1369   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1370
1371   AsyncSSLSocket::UniquePtr clientSock(
1372     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1373   AsyncSSLSocket::UniquePtr serverSock(
1374     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1375
1376   SSLHandshakeClient client(std::move(clientSock), false, true);
1377   SSLHandshakeServer server(std::move(serverSock), true, true);
1378
1379   eventBase.loop();
1380
1381   EXPECT_TRUE(client.handshakeVerify_);
1382   EXPECT_TRUE(client.handshakeSuccess_);
1383   EXPECT_TRUE(!client.handshakeError_);
1384   EXPECT_LE(0, client.handshakeTime.count());
1385   EXPECT_TRUE(!server.handshakeVerify_);
1386   EXPECT_TRUE(server.handshakeSuccess_);
1387   EXPECT_TRUE(!server.handshakeError_);
1388   EXPECT_LE(0, server.handshakeTime.count());
1389 }
1390
1391 /**
1392  * Verify that specifying that no validation should be performed allows an
1393  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1394  * callback.
1395  */
1396 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1397   EventBase eventBase;
1398   auto clientCtx = std::make_shared<SSLContext>();
1399   auto dfServerCtx = std::make_shared<SSLContext>();
1400
1401   int fds[2];
1402   getfds(fds);
1403   getctx(clientCtx, dfServerCtx);
1404
1405   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1406   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1407
1408   AsyncSSLSocket::UniquePtr clientSock(
1409     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1410   AsyncSSLSocket::UniquePtr serverSock(
1411     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1412
1413   SSLHandshakeClient client(std::move(clientSock), false, false);
1414   SSLHandshakeServer server(std::move(serverSock), false, false);
1415
1416   eventBase.loop();
1417
1418   EXPECT_TRUE(!client.handshakeVerify_);
1419   EXPECT_TRUE(client.handshakeSuccess_);
1420   EXPECT_TRUE(!client.handshakeError_);
1421   EXPECT_LE(0, client.handshakeTime.count());
1422   EXPECT_TRUE(!server.handshakeVerify_);
1423   EXPECT_TRUE(server.handshakeSuccess_);
1424   EXPECT_TRUE(!server.handshakeError_);
1425   EXPECT_LE(0, server.handshakeTime.count());
1426 }
1427
1428 /**
1429  * Test requireClientCert with client cert
1430  */
1431 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1432   EventBase eventBase;
1433   auto clientCtx = std::make_shared<SSLContext>();
1434   auto serverCtx = std::make_shared<SSLContext>();
1435   serverCtx->setVerificationOption(
1436       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1437   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1438   serverCtx->loadPrivateKey(testKey);
1439   serverCtx->loadCertificate(testCert);
1440   serverCtx->loadTrustedCertificates(testCA);
1441   serverCtx->loadClientCAList(testCA);
1442
1443   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1444   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1445   clientCtx->loadPrivateKey(testKey);
1446   clientCtx->loadCertificate(testCert);
1447   clientCtx->loadTrustedCertificates(testCA);
1448
1449   int fds[2];
1450   getfds(fds);
1451
1452   AsyncSSLSocket::UniquePtr clientSock(
1453       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1454   AsyncSSLSocket::UniquePtr serverSock(
1455       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1456
1457   SSLHandshakeClient client(std::move(clientSock), true, true);
1458   SSLHandshakeServer server(std::move(serverSock), true, true);
1459
1460   eventBase.loop();
1461
1462   EXPECT_TRUE(client.handshakeVerify_);
1463   EXPECT_TRUE(client.handshakeSuccess_);
1464   EXPECT_FALSE(client.handshakeError_);
1465   EXPECT_LE(0, client.handshakeTime.count());
1466   EXPECT_TRUE(server.handshakeVerify_);
1467   EXPECT_TRUE(server.handshakeSuccess_);
1468   EXPECT_FALSE(server.handshakeError_);
1469   EXPECT_LE(0, server.handshakeTime.count());
1470 }
1471
1472
1473 /**
1474  * Test requireClientCert with no client cert
1475  */
1476 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1477   EventBase eventBase;
1478   auto clientCtx = std::make_shared<SSLContext>();
1479   auto serverCtx = std::make_shared<SSLContext>();
1480   serverCtx->setVerificationOption(
1481       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1482   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1483   serverCtx->loadPrivateKey(testKey);
1484   serverCtx->loadCertificate(testCert);
1485   serverCtx->loadTrustedCertificates(testCA);
1486   serverCtx->loadClientCAList(testCA);
1487   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1488   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1489
1490   int fds[2];
1491   getfds(fds);
1492
1493   AsyncSSLSocket::UniquePtr clientSock(
1494       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1495   AsyncSSLSocket::UniquePtr serverSock(
1496       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1497
1498   SSLHandshakeClient client(std::move(clientSock), false, false);
1499   SSLHandshakeServer server(std::move(serverSock), false, false);
1500
1501   eventBase.loop();
1502
1503   EXPECT_FALSE(server.handshakeVerify_);
1504   EXPECT_FALSE(server.handshakeSuccess_);
1505   EXPECT_TRUE(server.handshakeError_);
1506   EXPECT_LE(0, client.handshakeTime.count());
1507   EXPECT_LE(0, server.handshakeTime.count());
1508 }
1509
1510 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1511   auto cert = getFileAsBuf(testCert);
1512   auto key = getFileAsBuf(testKey);
1513
1514   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1515   BIO_write(certBio.get(), cert.data(), cert.size());
1516   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1517   BIO_write(keyBio.get(), key.data(), key.size());
1518
1519   // Create SSL structs from buffers to get properties
1520   ssl::X509UniquePtr certStruct(
1521       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1522   ssl::EvpPkeyUniquePtr keyStruct(
1523       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1524   certBio = nullptr;
1525   keyBio = nullptr;
1526
1527   auto origCommonName = getCommonName(certStruct.get());
1528   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1529   certStruct = nullptr;
1530   keyStruct = nullptr;
1531
1532   auto ctx = std::make_shared<SSLContext>();
1533   ctx->loadPrivateKeyFromBufferPEM(key);
1534   ctx->loadCertificateFromBufferPEM(cert);
1535   ctx->loadTrustedCertificates(testCA);
1536
1537   ssl::SSLUniquePtr ssl(ctx->createSSL());
1538
1539   auto newCert = SSL_get_certificate(ssl.get());
1540   auto newKey = SSL_get_privatekey(ssl.get());
1541
1542   // Get properties from SSL struct
1543   auto newCommonName = getCommonName(newCert);
1544   auto newKeySize = EVP_PKEY_bits(newKey);
1545
1546   // Check that the key and cert have the expected properties
1547   EXPECT_EQ(origCommonName, newCommonName);
1548   EXPECT_EQ(origKeySize, newKeySize);
1549 }
1550
1551 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1552   EventBase eb;
1553
1554   // Set up SSL context.
1555   auto sslContext = std::make_shared<SSLContext>();
1556   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1557
1558   // create SSL socket
1559   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1560
1561   EXPECT_EQ(1500, socket->getMinWriteSize());
1562
1563   socket->setMinWriteSize(0);
1564   EXPECT_EQ(0, socket->getMinWriteSize());
1565   socket->setMinWriteSize(50000);
1566   EXPECT_EQ(50000, socket->getMinWriteSize());
1567 }
1568
1569 class ReadCallbackTerminator : public ReadCallback {
1570  public:
1571   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1572       : ReadCallback(wcb)
1573       , base_(base) {}
1574
1575   // Do not write data back, terminate the loop.
1576   void readDataAvailable(size_t len) noexcept override {
1577     std::cerr << "readDataAvailable, len " << len << std::endl;
1578
1579     currentBuffer.length = len;
1580
1581     buffers.push_back(currentBuffer);
1582     currentBuffer.reset();
1583     state = STATE_SUCCEEDED;
1584
1585     socket_->setReadCB(nullptr);
1586     base_->terminateLoopSoon();
1587   }
1588  private:
1589   EventBase* base_;
1590 };
1591
1592
1593 /**
1594  * Test a full unencrypted codepath
1595  */
1596 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1597   EventBase base;
1598
1599   auto clientCtx = std::make_shared<folly::SSLContext>();
1600   auto serverCtx = std::make_shared<folly::SSLContext>();
1601   int fds[2];
1602   getfds(fds);
1603   getctx(clientCtx, serverCtx);
1604   auto client = AsyncSSLSocket::newSocket(
1605                   clientCtx, &base, fds[0], false, true);
1606   auto server = AsyncSSLSocket::newSocket(
1607                   serverCtx, &base, fds[1], true, true);
1608
1609   ReadCallbackTerminator readCallback(&base, nullptr);
1610   server->setReadCB(&readCallback);
1611   readCallback.setSocket(server);
1612
1613   uint8_t buf[128];
1614   memset(buf, 'a', sizeof(buf));
1615   client->write(nullptr, buf, sizeof(buf));
1616
1617   // Check that bytes are unencrypted
1618   char c;
1619   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1620   EXPECT_EQ('a', c);
1621
1622   EventBaseAborter eba(&base, 3000);
1623   base.loop();
1624
1625   EXPECT_EQ(1, readCallback.buffers.size());
1626   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1627
1628   server->setReadCB(&readCallback);
1629
1630   // Unencrypted
1631   server->sslAccept(nullptr);
1632   client->sslConn(nullptr);
1633
1634   // Do NOT wait for handshake, writing should be queued and happen after
1635
1636   client->write(nullptr, buf, sizeof(buf));
1637
1638   // Check that bytes are *not* unencrypted
1639   char c2;
1640   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1641   EXPECT_NE('a', c2);
1642
1643
1644   base.loop();
1645
1646   EXPECT_EQ(2, readCallback.buffers.size());
1647   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1648 }
1649
1650 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1651   // Start listening on a local port
1652   WriteCallbackBase writeCallback;
1653   WriteErrorCallback readCallback(&writeCallback);
1654   HandshakeCallback handshakeCallback(&readCallback,
1655                                       HandshakeCallback::EXPECT_ERROR);
1656   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1657   TestSSLServer server(&acceptCallback);
1658
1659   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1660   socket->open();
1661   uint8_t buf[3] = {0x16, 0x03, 0x01};
1662   socket->write(buf, sizeof(buf));
1663   socket->closeWithReset();
1664
1665   handshakeCallback.waitForHandshake();
1666   EXPECT_NE(
1667       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1668   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1669 }
1670
1671 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1672   // Start listening on a local port
1673   WriteCallbackBase writeCallback;
1674   WriteErrorCallback readCallback(&writeCallback);
1675   HandshakeCallback handshakeCallback(&readCallback,
1676                                       HandshakeCallback::EXPECT_ERROR);
1677   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1678   TestSSLServer server(&acceptCallback);
1679
1680   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1681   socket->open();
1682   uint8_t buf[3] = {0x16, 0x03, 0x01};
1683   socket->write(buf, sizeof(buf));
1684   socket->close();
1685
1686   handshakeCallback.waitForHandshake();
1687   EXPECT_NE(
1688       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1689   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1690 }
1691
1692 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1693   // Start listening on a local port
1694   WriteCallbackBase writeCallback;
1695   WriteErrorCallback readCallback(&writeCallback);
1696   HandshakeCallback handshakeCallback(&readCallback,
1697                                       HandshakeCallback::EXPECT_ERROR);
1698   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1699   TestSSLServer server(&acceptCallback);
1700
1701   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1702   socket->open();
1703   uint8_t buf[256] = {0x16, 0x03};
1704   memset(buf + 2, 'a', sizeof(buf) - 2);
1705   socket->write(buf, sizeof(buf));
1706   socket->close();
1707
1708   handshakeCallback.waitForHandshake();
1709   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1710             std::string::npos);
1711 #if defined(OPENSSL_IS_BORINGSSL)
1712   EXPECT_NE(
1713       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1714       std::string::npos);
1715 #else
1716   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1717             std::string::npos);
1718 #endif
1719 }
1720
1721 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1722   using folly::ssl::OpenSSLUtils;
1723   EXPECT_EQ(
1724       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1725   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1726   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1727   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1728   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1729 }
1730
1731 #if FOLLY_ALLOW_TFO
1732
1733 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1734  public:
1735   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1736
1737   explicit MockAsyncTFOSSLSocket(
1738       std::shared_ptr<folly::SSLContext> sslCtx,
1739       EventBase* evb)
1740       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1741
1742   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1743 };
1744
1745 /**
1746  * Test connecting to, writing to, reading from, and closing the
1747  * connection to the SSL server with TFO.
1748  */
1749 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1750   // Start listening on a local port
1751   WriteCallbackBase writeCallback;
1752   ReadCallback readCallback(&writeCallback);
1753   HandshakeCallback handshakeCallback(&readCallback);
1754   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1755   TestSSLServer server(&acceptCallback, true);
1756
1757   // Set up SSL context.
1758   auto sslContext = std::make_shared<SSLContext>();
1759
1760   // connect
1761   auto socket =
1762       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1763   socket->enableTFO();
1764   socket->open();
1765
1766   // write()
1767   std::array<uint8_t, 128> buf;
1768   memset(buf.data(), 'a', buf.size());
1769   socket->write(buf.data(), buf.size());
1770
1771   // read()
1772   std::array<uint8_t, 128> readbuf;
1773   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1774   EXPECT_EQ(bytesRead, 128);
1775   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1776
1777   // close()
1778   socket->close();
1779 }
1780
1781 /**
1782  * Test connecting to, writing to, reading from, and closing the
1783  * connection to the SSL server with TFO.
1784  */
1785 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1786   // Start listening on a local port
1787   WriteCallbackBase writeCallback;
1788   ReadCallback readCallback(&writeCallback);
1789   HandshakeCallback handshakeCallback(&readCallback);
1790   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1791   TestSSLServer server(&acceptCallback, false);
1792
1793   // Set up SSL context.
1794   auto sslContext = std::make_shared<SSLContext>();
1795
1796   // connect
1797   auto socket =
1798       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1799   socket->enableTFO();
1800   socket->open();
1801
1802   // write()
1803   std::array<uint8_t, 128> buf;
1804   memset(buf.data(), 'a', buf.size());
1805   socket->write(buf.data(), buf.size());
1806
1807   // read()
1808   std::array<uint8_t, 128> readbuf;
1809   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1810   EXPECT_EQ(bytesRead, 128);
1811   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1812
1813   // close()
1814   socket->close();
1815 }
1816
1817 class ConnCallback : public AsyncSocket::ConnectCallback {
1818  public:
1819   virtual void connectSuccess() noexcept override {
1820     state = State::SUCCESS;
1821   }
1822
1823   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1824     state = State::ERROR;
1825     error = ex.what();
1826   }
1827
1828   enum class State { WAITING, SUCCESS, ERROR };
1829
1830   State state{State::WAITING};
1831   std::string error;
1832 };
1833
1834 template <class Cardinality>
1835 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1836     EventBase* evb,
1837     const SocketAddress& address,
1838     Cardinality cardinality) {
1839   // Set up SSL context.
1840   auto sslContext = std::make_shared<SSLContext>();
1841
1842   // connect
1843   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1844       new MockAsyncTFOSSLSocket(sslContext, evb));
1845   socket->enableTFO();
1846
1847   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1848       .Times(cardinality)
1849       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1850         sockaddr_storage addr;
1851         auto len = address.getAddress(&addr);
1852         return connect(fd, (const struct sockaddr*)&addr, len);
1853       }));
1854   return socket;
1855 }
1856
1857 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1858   // Start listening on a local port
1859   WriteCallbackBase writeCallback;
1860   ReadCallback readCallback(&writeCallback);
1861   HandshakeCallback handshakeCallback(&readCallback);
1862   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1863   TestSSLServer server(&acceptCallback, true);
1864
1865   EventBase evb;
1866
1867   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1868   ConnCallback ccb;
1869   socket->connect(&ccb, server.getAddress(), 30);
1870
1871   evb.loop();
1872   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1873
1874   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1875   evb.loop();
1876
1877   BlockingSocket sock(std::move(socket));
1878   // write()
1879   std::array<uint8_t, 128> buf;
1880   memset(buf.data(), 'a', buf.size());
1881   sock.write(buf.data(), buf.size());
1882
1883   // read()
1884   std::array<uint8_t, 128> readbuf;
1885   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1886   EXPECT_EQ(bytesRead, 128);
1887   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1888
1889   // close()
1890   sock.close();
1891 }
1892
1893 #if !defined(OPENSSL_IS_BORINGSSL)
1894 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1895   // Start listening on a local port
1896   ConnectTimeoutCallback acceptCallback;
1897   TestSSLServer server(&acceptCallback, true);
1898
1899   // Set up SSL context.
1900   auto sslContext = std::make_shared<SSLContext>();
1901
1902   // connect
1903   auto socket =
1904       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1905   socket->enableTFO();
1906   EXPECT_THROW(
1907       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1908 }
1909 #endif
1910
1911 #if !defined(OPENSSL_IS_BORINGSSL)
1912 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1913   // Start listening on a local port
1914   ConnectTimeoutCallback acceptCallback;
1915   TestSSLServer server(&acceptCallback, true);
1916
1917   EventBase evb;
1918
1919   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1920   ConnCallback ccb;
1921   // Set a short timeout
1922   socket->connect(&ccb, server.getAddress(), 1);
1923
1924   evb.loop();
1925   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1926 }
1927 #endif
1928
1929 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1930   // Start listening on a local port
1931   EmptyReadCallback readCallback;
1932   HandshakeCallback handshakeCallback(
1933       &readCallback, HandshakeCallback::EXPECT_ERROR);
1934   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1935   TestSSLServer server(&acceptCallback, true);
1936
1937   EventBase evb;
1938
1939   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1940   ConnCallback ccb;
1941   socket->connect(&ccb, server.getAddress(), 100);
1942
1943   evb.loop();
1944   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1945   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1946 }
1947
1948 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1949   // Start listening on a local port
1950   EventBase evb;
1951
1952   // Hopefully nothing is listening on this address
1953   SocketAddress addr("127.0.0.1", 65535);
1954   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1955   ConnCallback ccb;
1956   socket->connect(&ccb, addr, 100);
1957
1958   evb.loop();
1959   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1960   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1961 }
1962
1963 #endif
1964
1965 } // namespace
1966
1967 #ifdef SIGPIPE
1968 ///////////////////////////////////////////////////////////////////////////
1969 // init_unit_test_suite
1970 ///////////////////////////////////////////////////////////////////////////
1971 namespace {
1972 struct Initializer {
1973   Initializer() {
1974     signal(SIGPIPE, SIG_IGN);
1975   }
1976 };
1977 Initializer initializer;
1978 } // anonymous
1979 #endif