Add handshake and connect times
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <gtest/gtest.h>
28 #include <iostream>
29 #include <list>
30 #include <set>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <poll.h>
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <netinet/tcp.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
59 ctx_(new folly::SSLContext),
60     acb_(acb),
61   socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79   (void)ret;
80
81   std::cerr << "Accepting connections on " << address_ << std::endl;
82 }
83
84 void getfds(int fds[2]) {
85   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
86     FAIL() << "failed to create socketpair: " << strerror(errno);
87   }
88   for (int idx = 0; idx < 2; ++idx) {
89     int flags = fcntl(fds[idx], F_GETFL, 0);
90     if (flags == -1) {
91       FAIL() << "failed to get flags for socket " << idx << ": "
92              << strerror(errno);
93     }
94     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
95       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
96              << strerror(errno);
97     }
98   }
99 }
100
101 void getctx(
102   std::shared_ptr<folly::SSLContext> clientCtx,
103   std::shared_ptr<folly::SSLContext> serverCtx) {
104   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
105
106   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107   serverCtx->loadCertificate(
108       testCert);
109   serverCtx->loadPrivateKey(
110       testKey);
111 }
112
113 void sslsocketpair(
114   EventBase* eventBase,
115   AsyncSSLSocket::UniquePtr* clientSock,
116   AsyncSSLSocket::UniquePtr* serverSock) {
117   auto clientCtx = std::make_shared<folly::SSLContext>();
118   auto serverCtx = std::make_shared<folly::SSLContext>();
119   int fds[2];
120   getfds(fds);
121   getctx(clientCtx, serverCtx);
122   clientSock->reset(new AsyncSSLSocket(
123                       clientCtx, eventBase, fds[0], false));
124   serverSock->reset(new AsyncSSLSocket(
125                       serverCtx, eventBase, fds[1], true));
126
127   // (*clientSock)->setSendTimeout(100);
128   // (*serverSock)->setSendTimeout(100);
129 }
130
131 // client protocol filters
132 bool clientProtoFilterPickPony(unsigned char** client,
133   unsigned int* client_len, const unsigned char*, unsigned int ) {
134   //the protocol string in length prefixed byte string. the
135   //length byte is not included in the length
136   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
137   *client = p;
138   *client_len = 7;
139   return true;
140 }
141
142 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
143   const unsigned char*, unsigned int) {
144   return false;
145 }
146
147 /**
148  * Test connecting to, writing to, reading from, and closing the
149  * connection to the SSL server.
150  */
151 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
152   // Start listening on a local port
153   WriteCallbackBase writeCallback;
154   ReadCallback readCallback(&writeCallback);
155   HandshakeCallback handshakeCallback(&readCallback);
156   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
157   TestSSLServer server(&acceptCallback);
158
159   // Set up SSL context.
160   std::shared_ptr<SSLContext> sslContext(new SSLContext());
161   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
162   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
163   //sslContext->authenticate(true, false);
164
165   // connect
166   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
167                                                  sslContext);
168   socket->open();
169
170   // write()
171   uint8_t buf[128];
172   memset(buf, 'a', sizeof(buf));
173   socket->write(buf, sizeof(buf));
174
175   // read()
176   uint8_t readbuf[128];
177   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
178   EXPECT_EQ(bytesRead, 128);
179   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
180
181   // close()
182   socket->close();
183
184   cerr << "ConnectWriteReadClose test completed" << endl;
185 }
186
187 /**
188  * Negative test for handshakeError().
189  */
190 TEST(AsyncSSLSocketTest, HandshakeError) {
191   // Start listening on a local port
192   WriteCallbackBase writeCallback;
193   ReadCallback readCallback(&writeCallback);
194   HandshakeCallback handshakeCallback(&readCallback);
195   HandshakeErrorCallback acceptCallback(&handshakeCallback);
196   TestSSLServer server(&acceptCallback);
197
198   // Set up SSL context.
199   std::shared_ptr<SSLContext> sslContext(new SSLContext());
200   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
201
202   // connect
203   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
204                                                  sslContext);
205   // read()
206   bool ex = false;
207   try {
208     socket->open();
209
210     uint8_t readbuf[128];
211     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
212     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
213   } catch (AsyncSocketException &e) {
214     ex = true;
215   }
216   EXPECT_TRUE(ex);
217
218   // close()
219   socket->close();
220   cerr << "HandshakeError test completed" << endl;
221 }
222
223 /**
224  * Negative test for readError().
225  */
226 TEST(AsyncSSLSocketTest, ReadError) {
227   // Start listening on a local port
228   WriteCallbackBase writeCallback;
229   ReadErrorCallback readCallback(&writeCallback);
230   HandshakeCallback handshakeCallback(&readCallback);
231   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
232   TestSSLServer server(&acceptCallback);
233
234   // Set up SSL context.
235   std::shared_ptr<SSLContext> sslContext(new SSLContext());
236   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
237
238   // connect
239   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
240                                                  sslContext);
241   socket->open();
242
243   // write something to trigger ssl handshake
244   uint8_t buf[128];
245   memset(buf, 'a', sizeof(buf));
246   socket->write(buf, sizeof(buf));
247
248   socket->close();
249   cerr << "ReadError test completed" << endl;
250 }
251
252 /**
253  * Negative test for writeError().
254  */
255 TEST(AsyncSSLSocketTest, WriteError) {
256   // Start listening on a local port
257   WriteCallbackBase writeCallback;
258   WriteErrorCallback readCallback(&writeCallback);
259   HandshakeCallback handshakeCallback(&readCallback);
260   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
261   TestSSLServer server(&acceptCallback);
262
263   // Set up SSL context.
264   std::shared_ptr<SSLContext> sslContext(new SSLContext());
265   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
266
267   // connect
268   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
269                                                  sslContext);
270   socket->open();
271
272   // write something to trigger ssl handshake
273   uint8_t buf[128];
274   memset(buf, 'a', sizeof(buf));
275   socket->write(buf, sizeof(buf));
276
277   socket->close();
278   cerr << "WriteError test completed" << endl;
279 }
280
281 /**
282  * Test a socket with TCP_NODELAY unset.
283  */
284 TEST(AsyncSSLSocketTest, SocketWithDelay) {
285   // Start listening on a local port
286   WriteCallbackBase writeCallback;
287   ReadCallback readCallback(&writeCallback);
288   HandshakeCallback handshakeCallback(&readCallback);
289   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
290   TestSSLServer server(&acceptCallback);
291
292   // Set up SSL context.
293   std::shared_ptr<SSLContext> sslContext(new SSLContext());
294   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
295
296   // connect
297   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
298                                                  sslContext);
299   socket->open();
300
301   // write()
302   uint8_t buf[128];
303   memset(buf, 'a', sizeof(buf));
304   socket->write(buf, sizeof(buf));
305
306   // read()
307   uint8_t readbuf[128];
308   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
309   EXPECT_EQ(bytesRead, 128);
310   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
311
312   // close()
313   socket->close();
314
315   cerr << "SocketWithDelay test completed" << endl;
316 }
317
318 TEST(AsyncSSLSocketTest, NpnTestOverlap) {
319   EventBase eventBase;
320   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
321   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
322   int fds[2];
323   getfds(fds);
324   getctx(clientCtx, serverCtx);
325
326   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
327   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
328
329   AsyncSSLSocket::UniquePtr clientSock(
330     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
331   AsyncSSLSocket::UniquePtr serverSock(
332     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
333   NpnClient client(std::move(clientSock));
334   NpnServer server(std::move(serverSock));
335
336   eventBase.loop();
337
338   EXPECT_TRUE(client.nextProtoLength != 0);
339   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
340   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
341                            server.nextProtoLength), 0);
342   string selected((const char*)client.nextProto, client.nextProtoLength);
343   EXPECT_EQ(selected.compare("baz"), 0);
344 }
345
346 TEST(AsyncSSLSocketTest, NpnTestUnset) {
347   // Identical to above test, except that we want unset NPN before
348   // looping.
349   EventBase eventBase;
350   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
351   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
352   int fds[2];
353   getfds(fds);
354   getctx(clientCtx, serverCtx);
355
356   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
357   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
358
359   AsyncSSLSocket::UniquePtr clientSock(
360     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
361   AsyncSSLSocket::UniquePtr serverSock(
362     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
363
364   // unsetting NPN for any of [client, server] is enought to make NPN not
365   // work
366   clientCtx->unsetNextProtocols();
367
368   NpnClient client(std::move(clientSock));
369   NpnServer server(std::move(serverSock));
370
371   eventBase.loop();
372
373   EXPECT_TRUE(client.nextProtoLength == 0);
374   EXPECT_TRUE(server.nextProtoLength == 0);
375   EXPECT_TRUE(client.nextProto == nullptr);
376   EXPECT_TRUE(server.nextProto == nullptr);
377 }
378
379 TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
380   EventBase eventBase;
381   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
382   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
383   int fds[2];
384   getfds(fds);
385   getctx(clientCtx, serverCtx);
386
387   clientCtx->setAdvertisedNextProtocols({"blub"});
388   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
389
390   AsyncSSLSocket::UniquePtr clientSock(
391     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
392   AsyncSSLSocket::UniquePtr serverSock(
393     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
394   NpnClient client(std::move(clientSock));
395   NpnServer server(std::move(serverSock));
396
397   eventBase.loop();
398
399   EXPECT_TRUE(client.nextProtoLength != 0);
400   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
401   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
402                            server.nextProtoLength), 0);
403   string selected((const char*)client.nextProto, client.nextProtoLength);
404   EXPECT_EQ(selected.compare("blub"), 0);
405 }
406
407 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
408   EventBase eventBase;
409   auto clientCtx = std::make_shared<SSLContext>();
410   auto serverCtx = std::make_shared<SSLContext>();
411   int fds[2];
412   getfds(fds);
413   getctx(clientCtx, serverCtx);
414
415   clientCtx->setAdvertisedNextProtocols({"blub"});
416   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
417   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
418
419   AsyncSSLSocket::UniquePtr clientSock(
420     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
421   AsyncSSLSocket::UniquePtr serverSock(
422     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
423   NpnClient client(std::move(clientSock));
424   NpnServer server(std::move(serverSock));
425
426   eventBase.loop();
427
428   EXPECT_TRUE(client.nextProtoLength != 0);
429   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
430   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
431                            server.nextProtoLength), 0);
432   string selected((const char*)client.nextProto, client.nextProtoLength);
433   EXPECT_EQ(selected.compare("ponies"), 0);
434 }
435
436 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
437   EventBase eventBase;
438   auto clientCtx = std::make_shared<SSLContext>();
439   auto serverCtx = std::make_shared<SSLContext>();
440   int fds[2];
441   getfds(fds);
442   getctx(clientCtx, serverCtx);
443
444   clientCtx->setAdvertisedNextProtocols({"blub"});
445   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
446   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
447
448   AsyncSSLSocket::UniquePtr clientSock(
449     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
450   AsyncSSLSocket::UniquePtr serverSock(
451     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
452   NpnClient client(std::move(clientSock));
453   NpnServer server(std::move(serverSock));
454
455   eventBase.loop();
456
457   EXPECT_TRUE(client.nextProtoLength != 0);
458   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
459   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
460                            server.nextProtoLength), 0);
461   string selected((const char*)client.nextProto, client.nextProtoLength);
462   EXPECT_EQ(selected.compare("blub"), 0);
463 }
464
465 TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
466   // Probability that this test will fail is 2^-64, which could be considered
467   // as negligible.
468   const int kTries = 64;
469
470   std::set<string> selectedProtocols;
471   for (int i = 0; i < kTries; ++i) {
472     EventBase eventBase;
473     std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
474     std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
475     int fds[2];
476     getfds(fds);
477     getctx(clientCtx, serverCtx);
478
479     clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
480     serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
481         {1, {"bar"}}});
482
483
484     AsyncSSLSocket::UniquePtr clientSock(
485       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
486     AsyncSSLSocket::UniquePtr serverSock(
487       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
488     NpnClient client(std::move(clientSock));
489     NpnServer server(std::move(serverSock));
490
491     eventBase.loop();
492
493     EXPECT_TRUE(client.nextProtoLength != 0);
494     EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
495     EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
496                              server.nextProtoLength), 0);
497     string selected((const char*)client.nextProto, client.nextProtoLength);
498     selectedProtocols.insert(selected);
499   }
500   EXPECT_EQ(selectedProtocols.size(), 2);
501 }
502
503
504 #ifndef OPENSSL_NO_TLSEXT
505 /**
506  * 1. Client sends TLSEXT_HOSTNAME in client hello.
507  * 2. Server found a match SSL_CTX and use this SSL_CTX to
508  *    continue the SSL handshake.
509  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
510  */
511 TEST(AsyncSSLSocketTest, SNITestMatch) {
512   EventBase eventBase;
513   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
514   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
515   // Use the same SSLContext to continue the handshake after
516   // tlsext_hostname match.
517   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
518   const std::string serverName("xyz.newdev.facebook.com");
519   int fds[2];
520   getfds(fds);
521   getctx(clientCtx, dfServerCtx);
522
523   AsyncSSLSocket::UniquePtr clientSock(
524     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
525   AsyncSSLSocket::UniquePtr serverSock(
526     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
527   SNIClient client(std::move(clientSock));
528   SNIServer server(std::move(serverSock),
529                    dfServerCtx,
530                    hskServerCtx,
531                    serverName);
532
533   eventBase.loop();
534
535   EXPECT_TRUE(client.serverNameMatch);
536   EXPECT_TRUE(server.serverNameMatch);
537 }
538
539 /**
540  * 1. Client sends TLSEXT_HOSTNAME in client hello.
541  * 2. Server cannot find a matching SSL_CTX and continue to use
542  *    the current SSL_CTX to do the handshake.
543  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
544  */
545 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
546   EventBase eventBase;
547   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
548   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
549   // Use the same SSLContext to continue the handshake after
550   // tlsext_hostname match.
551   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
552   const std::string clientRequestingServerName("foo.com");
553   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
554
555   int fds[2];
556   getfds(fds);
557   getctx(clientCtx, dfServerCtx);
558
559   AsyncSSLSocket::UniquePtr clientSock(
560     new AsyncSSLSocket(clientCtx,
561                         &eventBase,
562                         fds[0],
563                         clientRequestingServerName));
564   AsyncSSLSocket::UniquePtr serverSock(
565     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
566   SNIClient client(std::move(clientSock));
567   SNIServer server(std::move(serverSock),
568                    dfServerCtx,
569                    hskServerCtx,
570                    serverExpectedServerName);
571
572   eventBase.loop();
573
574   EXPECT_TRUE(!client.serverNameMatch);
575   EXPECT_TRUE(!server.serverNameMatch);
576 }
577 /**
578  * 1. Client sends TLSEXT_HOSTNAME in client hello.
579  * 2. We then change the serverName.
580  * 3. We expect that we get 'false' as the result for serNameMatch.
581  */
582
583 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
584    EventBase eventBase;
585   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
586   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
587   // Use the same SSLContext to continue the handshake after
588   // tlsext_hostname match.
589   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
590   const std::string serverName("xyz.newdev.facebook.com");
591   int fds[2];
592   getfds(fds);
593   getctx(clientCtx, dfServerCtx);
594
595   AsyncSSLSocket::UniquePtr clientSock(
596     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
597   //Change the server name
598   std::string newName("new.com");
599   clientSock->setServerName(newName);
600   AsyncSSLSocket::UniquePtr serverSock(
601     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
602   SNIClient client(std::move(clientSock));
603   SNIServer server(std::move(serverSock),
604                    dfServerCtx,
605                    hskServerCtx,
606                    serverName);
607
608   eventBase.loop();
609
610   EXPECT_TRUE(!client.serverNameMatch);
611 }
612
613 /**
614  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
615  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
616  */
617 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
618   EventBase eventBase;
619   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
620   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
621   // Use the same SSLContext to continue the handshake after
622   // tlsext_hostname match.
623   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
624   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
625
626   int fds[2];
627   getfds(fds);
628   getctx(clientCtx, dfServerCtx);
629
630   AsyncSSLSocket::UniquePtr clientSock(
631     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
632   AsyncSSLSocket::UniquePtr serverSock(
633     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
634   SNIClient client(std::move(clientSock));
635   SNIServer server(std::move(serverSock),
636                    dfServerCtx,
637                    hskServerCtx,
638                    serverExpectedServerName);
639
640   eventBase.loop();
641
642   EXPECT_TRUE(!client.serverNameMatch);
643   EXPECT_TRUE(!server.serverNameMatch);
644 }
645
646 #endif
647 /**
648  * Test SSL client socket
649  */
650 TEST(AsyncSSLSocketTest, SSLClientTest) {
651   // Start listening on a local port
652   WriteCallbackBase writeCallback;
653   ReadCallback readCallback(&writeCallback);
654   HandshakeCallback handshakeCallback(&readCallback);
655   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
656   TestSSLServer server(&acceptCallback);
657
658   // Set up SSL client
659   EventBase eventBase;
660   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
661                                              1));
662
663   client->connect();
664   EventBaseAborter eba(&eventBase, 3000);
665   eventBase.loop();
666
667   EXPECT_EQ(client->getMiss(), 1);
668   EXPECT_EQ(client->getHit(), 0);
669
670   cerr << "SSLClientTest test completed" << endl;
671 }
672
673
674 /**
675  * Test SSL client socket session re-use
676  */
677 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
678   // Start listening on a local port
679   WriteCallbackBase writeCallback;
680   ReadCallback readCallback(&writeCallback);
681   HandshakeCallback handshakeCallback(&readCallback);
682   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
683   TestSSLServer server(&acceptCallback);
684
685   // Set up SSL client
686   EventBase eventBase;
687   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
688                                              10));
689
690   client->connect();
691   EventBaseAborter eba(&eventBase, 3000);
692   eventBase.loop();
693
694   EXPECT_EQ(client->getMiss(), 1);
695   EXPECT_EQ(client->getHit(), 9);
696
697   cerr << "SSLClientTestReuse test completed" << endl;
698 }
699
700 /**
701  * Test SSL client socket timeout
702  */
703 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
704   // Start listening on a local port
705   EmptyReadCallback readCallback;
706   HandshakeCallback handshakeCallback(&readCallback,
707                                       HandshakeCallback::EXPECT_ERROR);
708   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
709   TestSSLServer server(&acceptCallback);
710
711   // Set up SSL client
712   EventBase eventBase;
713   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
714                                              1, 10));
715   client->connect(true /* write before connect completes */);
716   EventBaseAborter eba(&eventBase, 3000);
717   eventBase.loop();
718
719   usleep(100000);
720   // This is checking that the connectError callback precedes any queued
721   // writeError callbacks.  This matches AsyncSocket's behavior
722   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
723   EXPECT_EQ(client->getErrors(), 1);
724   EXPECT_EQ(client->getMiss(), 0);
725   EXPECT_EQ(client->getHit(), 0);
726
727   cerr << "SSLClientTimeoutTest test completed" << endl;
728 }
729
730
731 /**
732  * Test SSL server async cache
733  */
734 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
735   // Start listening on a local port
736   WriteCallbackBase writeCallback;
737   ReadCallback readCallback(&writeCallback);
738   HandshakeCallback handshakeCallback(&readCallback);
739   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
740   TestSSLAsyncCacheServer server(&acceptCallback);
741
742   // Set up SSL client
743   EventBase eventBase;
744   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
745                                              10, 500));
746
747   client->connect();
748   EventBaseAborter eba(&eventBase, 3000);
749   eventBase.loop();
750
751   EXPECT_EQ(server.getAsyncCallbacks(), 18);
752   EXPECT_EQ(server.getAsyncLookups(), 9);
753   EXPECT_EQ(client->getMiss(), 10);
754   EXPECT_EQ(client->getHit(), 0);
755
756   cerr << "SSLServerAsyncCacheTest test completed" << endl;
757 }
758
759
760 /**
761  * Test SSL server accept timeout with cache path
762  */
763 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
764   // Start listening on a local port
765   WriteCallbackBase writeCallback;
766   ReadCallback readCallback(&writeCallback);
767   EmptyReadCallback clientReadCallback;
768   HandshakeCallback handshakeCallback(&readCallback);
769   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
770   TestSSLAsyncCacheServer server(&acceptCallback);
771
772   // Set up SSL client
773   EventBase eventBase;
774   // only do a TCP connect
775   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
776   sock->connect(nullptr, server.getAddress());
777   clientReadCallback.tcpSocket_ = sock;
778   sock->setReadCB(&clientReadCallback);
779
780   EventBaseAborter eba(&eventBase, 3000);
781   eventBase.loop();
782
783   EXPECT_EQ(readCallback.state, STATE_WAITING);
784
785   cerr << "SSLServerTimeoutTest test completed" << endl;
786 }
787
788 /**
789  * Test SSL server accept timeout with cache path
790  */
791 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
792   // Start listening on a local port
793   WriteCallbackBase writeCallback;
794   ReadCallback readCallback(&writeCallback);
795   HandshakeCallback handshakeCallback(&readCallback);
796   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
797   TestSSLAsyncCacheServer server(&acceptCallback);
798
799   // Set up SSL client
800   EventBase eventBase;
801   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
802                                              2));
803
804   client->connect();
805   EventBaseAborter eba(&eventBase, 3000);
806   eventBase.loop();
807
808   EXPECT_EQ(server.getAsyncCallbacks(), 1);
809   EXPECT_EQ(server.getAsyncLookups(), 1);
810   EXPECT_EQ(client->getErrors(), 1);
811   EXPECT_EQ(client->getMiss(), 1);
812   EXPECT_EQ(client->getHit(), 0);
813
814   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
815 }
816
817 /**
818  * Test SSL server accept timeout with cache path
819  */
820 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
821   // Start listening on a local port
822   WriteCallbackBase writeCallback;
823   ReadCallback readCallback(&writeCallback);
824   HandshakeCallback handshakeCallback(&readCallback,
825                                       HandshakeCallback::EXPECT_ERROR);
826   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
827   TestSSLAsyncCacheServer server(&acceptCallback, 500);
828
829   // Set up SSL client
830   EventBase eventBase;
831   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
832                                              2, 100));
833
834   client->connect();
835   EventBaseAborter eba(&eventBase, 3000);
836   eventBase.loop();
837
838   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
839       handshakeCallback.closeSocket();});
840   // give time for the cache lookup to come back and find it closed
841   usleep(500000);
842
843   EXPECT_EQ(server.getAsyncCallbacks(), 1);
844   EXPECT_EQ(server.getAsyncLookups(), 1);
845   EXPECT_EQ(client->getErrors(), 1);
846   EXPECT_EQ(client->getMiss(), 1);
847   EXPECT_EQ(client->getHit(), 0);
848
849   cerr << "SSLServerCacheCloseTest test completed" << endl;
850 }
851
852 /**
853  * Verify Client Ciphers obtained using SSL MSG Callback.
854  */
855 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
856   EventBase eventBase;
857   auto clientCtx = std::make_shared<SSLContext>();
858   auto serverCtx = std::make_shared<SSLContext>();
859   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
860   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
861   serverCtx->loadPrivateKey(testKey);
862   serverCtx->loadCertificate(testCert);
863   serverCtx->loadTrustedCertificates(testCA);
864   serverCtx->loadClientCAList(testCA);
865
866   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
867   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
868   clientCtx->loadPrivateKey(testKey);
869   clientCtx->loadCertificate(testCert);
870   clientCtx->loadTrustedCertificates(testCA);
871
872   int fds[2];
873   getfds(fds);
874
875   AsyncSSLSocket::UniquePtr clientSock(
876       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
877   AsyncSSLSocket::UniquePtr serverSock(
878       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
879
880   SSLHandshakeClient client(std::move(clientSock), true, true);
881   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
882
883   eventBase.loop();
884
885   EXPECT_EQ(server.clientCiphers_,
886             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
887   EXPECT_TRUE(client.handshakeVerify_);
888   EXPECT_TRUE(client.handshakeSuccess_);
889   EXPECT_TRUE(!client.handshakeError_);
890   EXPECT_TRUE(server.handshakeVerify_);
891   EXPECT_TRUE(server.handshakeSuccess_);
892   EXPECT_TRUE(!server.handshakeError_);
893 }
894
895 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
896   EventBase eventBase;
897   auto ctx = std::make_shared<SSLContext>();
898
899   int fds[2];
900   getfds(fds);
901
902   int bufLen = 42;
903   uint8_t majorVersion = 18;
904   uint8_t minorVersion = 25;
905
906   // Create callback buf
907   auto buf = IOBuf::create(bufLen);
908   buf->append(bufLen);
909   folly::io::RWPrivateCursor cursor(buf.get());
910   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
911   cursor.write<uint16_t>(0);
912   cursor.write<uint8_t>(38);
913   cursor.write<uint8_t>(majorVersion);
914   cursor.write<uint8_t>(minorVersion);
915   cursor.skip(32);
916   cursor.write<uint32_t>(0);
917
918   SSL* ssl = ctx->createSSL();
919   SCOPE_EXIT { SSL_free(ssl); };
920   AsyncSSLSocket::UniquePtr sock(
921       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
922   sock->enableClientHelloParsing();
923
924   // Test client hello parsing in one packet
925   AsyncSSLSocket::clientHelloParsingCallback(
926       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
927   buf.reset();
928
929   auto parsedClientHello = sock->getClientHelloInfo();
930   EXPECT_TRUE(parsedClientHello != nullptr);
931   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
932   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
933 }
934
935 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
936   EventBase eventBase;
937   auto ctx = std::make_shared<SSLContext>();
938
939   int fds[2];
940   getfds(fds);
941
942   int bufLen = 42;
943   uint8_t majorVersion = 18;
944   uint8_t minorVersion = 25;
945
946   // Create callback buf
947   auto buf = IOBuf::create(bufLen);
948   buf->append(bufLen);
949   folly::io::RWPrivateCursor cursor(buf.get());
950   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
951   cursor.write<uint16_t>(0);
952   cursor.write<uint8_t>(38);
953   cursor.write<uint8_t>(majorVersion);
954   cursor.write<uint8_t>(minorVersion);
955   cursor.skip(32);
956   cursor.write<uint32_t>(0);
957
958   SSL* ssl = ctx->createSSL();
959   SCOPE_EXIT { SSL_free(ssl); };
960   AsyncSSLSocket::UniquePtr sock(
961       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
962   sock->enableClientHelloParsing();
963
964   // Test parsing with two packets with first packet size < 3
965   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
966   AsyncSSLSocket::clientHelloParsingCallback(
967       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
968       ssl, sock.get());
969   bufCopy.reset();
970   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
971   AsyncSSLSocket::clientHelloParsingCallback(
972       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
973       ssl, sock.get());
974   bufCopy.reset();
975
976   auto parsedClientHello = sock->getClientHelloInfo();
977   EXPECT_TRUE(parsedClientHello != nullptr);
978   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
979   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
980 }
981
982 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
983   EventBase eventBase;
984   auto ctx = std::make_shared<SSLContext>();
985
986   int fds[2];
987   getfds(fds);
988
989   int bufLen = 42;
990   uint8_t majorVersion = 18;
991   uint8_t minorVersion = 25;
992
993   // Create callback buf
994   auto buf = IOBuf::create(bufLen);
995   buf->append(bufLen);
996   folly::io::RWPrivateCursor cursor(buf.get());
997   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
998   cursor.write<uint16_t>(0);
999   cursor.write<uint8_t>(38);
1000   cursor.write<uint8_t>(majorVersion);
1001   cursor.write<uint8_t>(minorVersion);
1002   cursor.skip(32);
1003   cursor.write<uint32_t>(0);
1004
1005   SSL* ssl = ctx->createSSL();
1006   SCOPE_EXIT { SSL_free(ssl); };
1007   AsyncSSLSocket::UniquePtr sock(
1008       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1009   sock->enableClientHelloParsing();
1010
1011   // Test parsing with multiple small packets
1012   for (uint64_t i = 0; i < buf->length(); i += 3) {
1013     auto bufCopy = folly::IOBuf::copyBuffer(
1014         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1015     AsyncSSLSocket::clientHelloParsingCallback(
1016         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1017         ssl, sock.get());
1018     bufCopy.reset();
1019   }
1020
1021   auto parsedClientHello = sock->getClientHelloInfo();
1022   EXPECT_TRUE(parsedClientHello != nullptr);
1023   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1024   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1025 }
1026
1027 /**
1028  * Verify sucessful behavior of SSL certificate validation.
1029  */
1030 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1031   EventBase eventBase;
1032   auto clientCtx = std::make_shared<SSLContext>();
1033   auto dfServerCtx = std::make_shared<SSLContext>();
1034
1035   int fds[2];
1036   getfds(fds);
1037   getctx(clientCtx, dfServerCtx);
1038
1039   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1040   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1041
1042   AsyncSSLSocket::UniquePtr clientSock(
1043     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1044   AsyncSSLSocket::UniquePtr serverSock(
1045     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1046
1047   SSLHandshakeClient client(std::move(clientSock), true, true);
1048   clientCtx->loadTrustedCertificates(testCA);
1049
1050   SSLHandshakeServer server(std::move(serverSock), true, true);
1051
1052   eventBase.loop();
1053
1054   EXPECT_TRUE(client.handshakeVerify_);
1055   EXPECT_TRUE(client.handshakeSuccess_);
1056   EXPECT_TRUE(!client.handshakeError_);
1057   EXPECT_LE(0, client.handshakeTime.count());
1058   EXPECT_TRUE(!server.handshakeVerify_);
1059   EXPECT_TRUE(server.handshakeSuccess_);
1060   EXPECT_TRUE(!server.handshakeError_);
1061   EXPECT_LE(0, server.handshakeTime.count());
1062 }
1063
1064 /**
1065  * Verify that the client's verification callback is able to fail SSL
1066  * connection establishment.
1067  */
1068 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1069   EventBase eventBase;
1070   auto clientCtx = std::make_shared<SSLContext>();
1071   auto dfServerCtx = std::make_shared<SSLContext>();
1072
1073   int fds[2];
1074   getfds(fds);
1075   getctx(clientCtx, dfServerCtx);
1076
1077   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1078   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1079
1080   AsyncSSLSocket::UniquePtr clientSock(
1081     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1082   AsyncSSLSocket::UniquePtr serverSock(
1083     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1084
1085   SSLHandshakeClient client(std::move(clientSock), true, false);
1086   clientCtx->loadTrustedCertificates(testCA);
1087
1088   SSLHandshakeServer server(std::move(serverSock), true, true);
1089
1090   eventBase.loop();
1091
1092   EXPECT_TRUE(client.handshakeVerify_);
1093   EXPECT_TRUE(!client.handshakeSuccess_);
1094   EXPECT_TRUE(client.handshakeError_);
1095   EXPECT_LE(0, client.handshakeTime.count());
1096   EXPECT_TRUE(!server.handshakeVerify_);
1097   EXPECT_TRUE(!server.handshakeSuccess_);
1098   EXPECT_TRUE(server.handshakeError_);
1099   EXPECT_LE(0, server.handshakeTime.count());
1100 }
1101
1102 /**
1103  * Verify that the options in SSLContext can be overridden in
1104  * sslConnect/Accept.i.e specifying that no validation should be performed
1105  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1106  * the validation callback.
1107  */
1108 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1109   EventBase eventBase;
1110   auto clientCtx = std::make_shared<SSLContext>();
1111   auto dfServerCtx = std::make_shared<SSLContext>();
1112
1113   int fds[2];
1114   getfds(fds);
1115   getctx(clientCtx, dfServerCtx);
1116
1117   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1118   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1119
1120   AsyncSSLSocket::UniquePtr clientSock(
1121     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1122   AsyncSSLSocket::UniquePtr serverSock(
1123     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1124
1125   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1126   clientCtx->loadTrustedCertificates(testCA);
1127
1128   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1129
1130   eventBase.loop();
1131
1132   EXPECT_TRUE(!client.handshakeVerify_);
1133   EXPECT_TRUE(client.handshakeSuccess_);
1134   EXPECT_TRUE(!client.handshakeError_);
1135   EXPECT_LE(0, client.handshakeTime.count());
1136   EXPECT_TRUE(!server.handshakeVerify_);
1137   EXPECT_TRUE(server.handshakeSuccess_);
1138   EXPECT_TRUE(!server.handshakeError_);
1139   EXPECT_LE(0, server.handshakeTime.count());
1140 }
1141
1142 /**
1143  * Verify that the options in SSLContext can be overridden in
1144  * sslConnect/Accept. Enable verification even if context says otherwise.
1145  * Test requireClientCert with client cert
1146  */
1147 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1148   EventBase eventBase;
1149   auto clientCtx = std::make_shared<SSLContext>();
1150   auto serverCtx = std::make_shared<SSLContext>();
1151   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1152   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1153   serverCtx->loadPrivateKey(testKey);
1154   serverCtx->loadCertificate(testCert);
1155   serverCtx->loadTrustedCertificates(testCA);
1156   serverCtx->loadClientCAList(testCA);
1157
1158   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1159   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1160   clientCtx->loadPrivateKey(testKey);
1161   clientCtx->loadCertificate(testCert);
1162   clientCtx->loadTrustedCertificates(testCA);
1163
1164   int fds[2];
1165   getfds(fds);
1166
1167   AsyncSSLSocket::UniquePtr clientSock(
1168       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1169   AsyncSSLSocket::UniquePtr serverSock(
1170       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1171
1172   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1173   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1174
1175   eventBase.loop();
1176
1177   EXPECT_TRUE(client.handshakeVerify_);
1178   EXPECT_TRUE(client.handshakeSuccess_);
1179   EXPECT_FALSE(client.handshakeError_);
1180   EXPECT_LE(0, client.handshakeTime.count());
1181   EXPECT_TRUE(server.handshakeVerify_);
1182   EXPECT_TRUE(server.handshakeSuccess_);
1183   EXPECT_FALSE(server.handshakeError_);
1184   EXPECT_LE(0, server.handshakeTime.count());
1185 }
1186
1187 /**
1188  * Verify that the client's verification callback is able to override
1189  * the preverification failure and allow a successful connection.
1190  */
1191 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1192   EventBase eventBase;
1193   auto clientCtx = std::make_shared<SSLContext>();
1194   auto dfServerCtx = std::make_shared<SSLContext>();
1195
1196   int fds[2];
1197   getfds(fds);
1198   getctx(clientCtx, dfServerCtx);
1199
1200   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1201   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1202
1203   AsyncSSLSocket::UniquePtr clientSock(
1204     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1205   AsyncSSLSocket::UniquePtr serverSock(
1206     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1207
1208   SSLHandshakeClient client(std::move(clientSock), false, true);
1209   SSLHandshakeServer server(std::move(serverSock), true, true);
1210
1211   eventBase.loop();
1212
1213   EXPECT_TRUE(client.handshakeVerify_);
1214   EXPECT_TRUE(client.handshakeSuccess_);
1215   EXPECT_TRUE(!client.handshakeError_);
1216   EXPECT_LE(0, client.handshakeTime.count());
1217   EXPECT_TRUE(!server.handshakeVerify_);
1218   EXPECT_TRUE(server.handshakeSuccess_);
1219   EXPECT_TRUE(!server.handshakeError_);
1220   EXPECT_LE(0, server.handshakeTime.count());
1221 }
1222
1223 /**
1224  * Verify that specifying that no validation should be performed allows an
1225  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1226  * callback.
1227  */
1228 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1229   EventBase eventBase;
1230   auto clientCtx = std::make_shared<SSLContext>();
1231   auto dfServerCtx = std::make_shared<SSLContext>();
1232
1233   int fds[2];
1234   getfds(fds);
1235   getctx(clientCtx, dfServerCtx);
1236
1237   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1238   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1239
1240   AsyncSSLSocket::UniquePtr clientSock(
1241     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1242   AsyncSSLSocket::UniquePtr serverSock(
1243     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1244
1245   SSLHandshakeClient client(std::move(clientSock), false, false);
1246   SSLHandshakeServer server(std::move(serverSock), false, false);
1247
1248   eventBase.loop();
1249
1250   EXPECT_TRUE(!client.handshakeVerify_);
1251   EXPECT_TRUE(client.handshakeSuccess_);
1252   EXPECT_TRUE(!client.handshakeError_);
1253   EXPECT_LE(0, client.handshakeTime.count());
1254   EXPECT_TRUE(!server.handshakeVerify_);
1255   EXPECT_TRUE(server.handshakeSuccess_);
1256   EXPECT_TRUE(!server.handshakeError_);
1257   EXPECT_LE(0, server.handshakeTime.count());
1258 }
1259
1260 /**
1261  * Test requireClientCert with client cert
1262  */
1263 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1264   EventBase eventBase;
1265   auto clientCtx = std::make_shared<SSLContext>();
1266   auto serverCtx = std::make_shared<SSLContext>();
1267   serverCtx->setVerificationOption(
1268       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1269   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1270   serverCtx->loadPrivateKey(testKey);
1271   serverCtx->loadCertificate(testCert);
1272   serverCtx->loadTrustedCertificates(testCA);
1273   serverCtx->loadClientCAList(testCA);
1274
1275   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1276   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1277   clientCtx->loadPrivateKey(testKey);
1278   clientCtx->loadCertificate(testCert);
1279   clientCtx->loadTrustedCertificates(testCA);
1280
1281   int fds[2];
1282   getfds(fds);
1283
1284   AsyncSSLSocket::UniquePtr clientSock(
1285       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1286   AsyncSSLSocket::UniquePtr serverSock(
1287       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1288
1289   SSLHandshakeClient client(std::move(clientSock), true, true);
1290   SSLHandshakeServer server(std::move(serverSock), true, true);
1291
1292   eventBase.loop();
1293
1294   EXPECT_TRUE(client.handshakeVerify_);
1295   EXPECT_TRUE(client.handshakeSuccess_);
1296   EXPECT_FALSE(client.handshakeError_);
1297   EXPECT_LE(0, client.handshakeTime.count());
1298   EXPECT_TRUE(server.handshakeVerify_);
1299   EXPECT_TRUE(server.handshakeSuccess_);
1300   EXPECT_FALSE(server.handshakeError_);
1301   EXPECT_LE(0, server.handshakeTime.count());
1302 }
1303
1304
1305 /**
1306  * Test requireClientCert with no client cert
1307  */
1308 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1309   EventBase eventBase;
1310   auto clientCtx = std::make_shared<SSLContext>();
1311   auto serverCtx = std::make_shared<SSLContext>();
1312   serverCtx->setVerificationOption(
1313       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1314   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1315   serverCtx->loadPrivateKey(testKey);
1316   serverCtx->loadCertificate(testCert);
1317   serverCtx->loadTrustedCertificates(testCA);
1318   serverCtx->loadClientCAList(testCA);
1319   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1320   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1321
1322   int fds[2];
1323   getfds(fds);
1324
1325   AsyncSSLSocket::UniquePtr clientSock(
1326       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1327   AsyncSSLSocket::UniquePtr serverSock(
1328       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1329
1330   SSLHandshakeClient client(std::move(clientSock), false, false);
1331   SSLHandshakeServer server(std::move(serverSock), false, false);
1332
1333   eventBase.loop();
1334
1335   EXPECT_FALSE(server.handshakeVerify_);
1336   EXPECT_FALSE(server.handshakeSuccess_);
1337   EXPECT_TRUE(server.handshakeError_);
1338   EXPECT_LE(0, client.handshakeTime.count());
1339   EXPECT_LE(0, server.handshakeTime.count());
1340 }
1341
1342 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1343   EventBase eb;
1344
1345   // Set up SSL context.
1346   auto sslContext = std::make_shared<SSLContext>();
1347   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1348
1349   // create SSL socket
1350   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1351
1352   EXPECT_EQ(1500, socket->getMinWriteSize());
1353
1354   socket->setMinWriteSize(0);
1355   EXPECT_EQ(0, socket->getMinWriteSize());
1356   socket->setMinWriteSize(50000);
1357   EXPECT_EQ(50000, socket->getMinWriteSize());
1358 }
1359
1360 class ReadCallbackTerminator : public ReadCallback {
1361  public:
1362   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1363       : ReadCallback(wcb)
1364       , base_(base) {}
1365
1366   // Do not write data back, terminate the loop.
1367   void readDataAvailable(size_t len) noexcept override {
1368     std::cerr << "readDataAvailable, len " << len << std::endl;
1369
1370     currentBuffer.length = len;
1371
1372     buffers.push_back(currentBuffer);
1373     currentBuffer.reset();
1374     state = STATE_SUCCEEDED;
1375
1376     socket_->setReadCB(nullptr);
1377     base_->terminateLoopSoon();
1378   }
1379  private:
1380   EventBase* base_;
1381 };
1382
1383
1384 /**
1385  * Test a full unencrypted codepath
1386  */
1387 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1388   EventBase base;
1389
1390   auto clientCtx = std::make_shared<folly::SSLContext>();
1391   auto serverCtx = std::make_shared<folly::SSLContext>();
1392   int fds[2];
1393   getfds(fds);
1394   getctx(clientCtx, serverCtx);
1395   auto client = AsyncSSLSocket::newSocket(
1396                   clientCtx, &base, fds[0], false, true);
1397   auto server = AsyncSSLSocket::newSocket(
1398                   serverCtx, &base, fds[1], true, true);
1399
1400   ReadCallbackTerminator readCallback(&base, nullptr);
1401   server->setReadCB(&readCallback);
1402   readCallback.setSocket(server);
1403
1404   uint8_t buf[128];
1405   memset(buf, 'a', sizeof(buf));
1406   client->write(nullptr, buf, sizeof(buf));
1407
1408   // Check that bytes are unencrypted
1409   char c;
1410   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1411   EXPECT_EQ('a', c);
1412
1413   EventBaseAborter eba(&base, 3000);
1414   base.loop();
1415
1416   EXPECT_EQ(1, readCallback.buffers.size());
1417   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1418
1419   server->setReadCB(&readCallback);
1420
1421   // Unencrypted
1422   server->sslAccept(nullptr);
1423   client->sslConn(nullptr);
1424
1425   // Do NOT wait for handshake, writing should be queued and happen after
1426
1427   client->write(nullptr, buf, sizeof(buf));
1428
1429   // Check that bytes are *not* unencrypted
1430   char c2;
1431   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1432   EXPECT_NE('a', c2);
1433
1434
1435   base.loop();
1436
1437   EXPECT_EQ(2, readCallback.buffers.size());
1438   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1439 }
1440
1441 } // namespace
1442
1443 ///////////////////////////////////////////////////////////////////////////
1444 // init_unit_test_suite
1445 ///////////////////////////////////////////////////////////////////////////
1446 namespace {
1447 struct Initializer {
1448   Initializer() {
1449     signal(SIGPIPE, SIG_IGN);
1450   }
1451 };
1452 Initializer initializer;
1453 } // anonymous