add callback to specify a client next protocol filter
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <gtest/gtest.h>
28 #include <iostream>
29 #include <list>
30 #include <set>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <poll.h>
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <netinet/tcp.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
59 ctx_(new folly::SSLContext),
60     acb_(acb),
61   socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79
80   std::cerr << "Accepting connections on " << address_ << std::endl;
81 }
82
83 void getfds(int fds[2]) {
84   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
85     FAIL() << "failed to create socketpair: " << strerror(errno);
86   }
87   for (int idx = 0; idx < 2; ++idx) {
88     int flags = fcntl(fds[idx], F_GETFL, 0);
89     if (flags == -1) {
90       FAIL() << "failed to get flags for socket " << idx << ": "
91              << strerror(errno);
92     }
93     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
94       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
95              << strerror(errno);
96     }
97   }
98 }
99
100 void getctx(
101   std::shared_ptr<folly::SSLContext> clientCtx,
102   std::shared_ptr<folly::SSLContext> serverCtx) {
103   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
104
105   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
106   serverCtx->loadCertificate(
107       testCert);
108   serverCtx->loadPrivateKey(
109       testKey);
110 }
111
112 void sslsocketpair(
113   EventBase* eventBase,
114   AsyncSSLSocket::UniquePtr* clientSock,
115   AsyncSSLSocket::UniquePtr* serverSock) {
116   auto clientCtx = std::make_shared<folly::SSLContext>();
117   auto serverCtx = std::make_shared<folly::SSLContext>();
118   int fds[2];
119   getfds(fds);
120   getctx(clientCtx, serverCtx);
121   clientSock->reset(new AsyncSSLSocket(
122                       clientCtx, eventBase, fds[0], false));
123   serverSock->reset(new AsyncSSLSocket(
124                       serverCtx, eventBase, fds[1], true));
125
126   // (*clientSock)->setSendTimeout(100);
127   // (*serverSock)->setSendTimeout(100);
128 }
129
130 // client protocol filters
131 bool clientProtoFilterPickPony(unsigned char** client,
132   unsigned int* client_len, const unsigned char*, unsigned int ) {
133   //the protocol string in length prefixed byte string. the
134   //length byte is not included in the length
135   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
136   *client = p;
137   *client_len = 7;
138   return true;
139 }
140
141 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
142   const unsigned char*, unsigned int) {
143   return false;
144 }
145
146 /**
147  * Test connecting to, writing to, reading from, and closing the
148  * connection to the SSL server.
149  */
150 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
151   // Start listening on a local port
152   WriteCallbackBase writeCallback;
153   ReadCallback readCallback(&writeCallback);
154   HandshakeCallback handshakeCallback(&readCallback);
155   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
156   TestSSLServer server(&acceptCallback);
157
158   // Set up SSL context.
159   std::shared_ptr<SSLContext> sslContext(new SSLContext());
160   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
161   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
162   //sslContext->authenticate(true, false);
163
164   // connect
165   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
166                                                  sslContext);
167   socket->open();
168
169   // write()
170   uint8_t buf[128];
171   memset(buf, 'a', sizeof(buf));
172   socket->write(buf, sizeof(buf));
173
174   // read()
175   uint8_t readbuf[128];
176   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
177   EXPECT_EQ(bytesRead, 128);
178   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
179
180   // close()
181   socket->close();
182
183   cerr << "ConnectWriteReadClose test completed" << endl;
184 }
185
186 /**
187  * Negative test for handshakeError().
188  */
189 TEST(AsyncSSLSocketTest, HandshakeError) {
190   // Start listening on a local port
191   WriteCallbackBase writeCallback;
192   ReadCallback readCallback(&writeCallback);
193   HandshakeCallback handshakeCallback(&readCallback);
194   HandshakeErrorCallback acceptCallback(&handshakeCallback);
195   TestSSLServer server(&acceptCallback);
196
197   // Set up SSL context.
198   std::shared_ptr<SSLContext> sslContext(new SSLContext());
199   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
200
201   // connect
202   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
203                                                  sslContext);
204   // read()
205   bool ex = false;
206   try {
207     socket->open();
208
209     uint8_t readbuf[128];
210     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
211   } catch (AsyncSocketException &e) {
212     ex = true;
213   }
214   EXPECT_TRUE(ex);
215
216   // close()
217   socket->close();
218   cerr << "HandshakeError test completed" << endl;
219 }
220
221 /**
222  * Negative test for readError().
223  */
224 TEST(AsyncSSLSocketTest, ReadError) {
225   // Start listening on a local port
226   WriteCallbackBase writeCallback;
227   ReadErrorCallback readCallback(&writeCallback);
228   HandshakeCallback handshakeCallback(&readCallback);
229   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
230   TestSSLServer server(&acceptCallback);
231
232   // Set up SSL context.
233   std::shared_ptr<SSLContext> sslContext(new SSLContext());
234   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
235
236   // connect
237   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
238                                                  sslContext);
239   socket->open();
240
241   // write something to trigger ssl handshake
242   uint8_t buf[128];
243   memset(buf, 'a', sizeof(buf));
244   socket->write(buf, sizeof(buf));
245
246   socket->close();
247   cerr << "ReadError test completed" << endl;
248 }
249
250 /**
251  * Negative test for writeError().
252  */
253 TEST(AsyncSSLSocketTest, WriteError) {
254   // Start listening on a local port
255   WriteCallbackBase writeCallback;
256   WriteErrorCallback readCallback(&writeCallback);
257   HandshakeCallback handshakeCallback(&readCallback);
258   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
259   TestSSLServer server(&acceptCallback);
260
261   // Set up SSL context.
262   std::shared_ptr<SSLContext> sslContext(new SSLContext());
263   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
264
265   // connect
266   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
267                                                  sslContext);
268   socket->open();
269
270   // write something to trigger ssl handshake
271   uint8_t buf[128];
272   memset(buf, 'a', sizeof(buf));
273   socket->write(buf, sizeof(buf));
274
275   socket->close();
276   cerr << "WriteError test completed" << endl;
277 }
278
279 /**
280  * Test a socket with TCP_NODELAY unset.
281  */
282 TEST(AsyncSSLSocketTest, SocketWithDelay) {
283   // Start listening on a local port
284   WriteCallbackBase writeCallback;
285   ReadCallback readCallback(&writeCallback);
286   HandshakeCallback handshakeCallback(&readCallback);
287   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
288   TestSSLServer server(&acceptCallback);
289
290   // Set up SSL context.
291   std::shared_ptr<SSLContext> sslContext(new SSLContext());
292   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
293
294   // connect
295   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
296                                                  sslContext);
297   socket->open();
298
299   // write()
300   uint8_t buf[128];
301   memset(buf, 'a', sizeof(buf));
302   socket->write(buf, sizeof(buf));
303
304   // read()
305   uint8_t readbuf[128];
306   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
307   EXPECT_EQ(bytesRead, 128);
308   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
309
310   // close()
311   socket->close();
312
313   cerr << "SocketWithDelay test completed" << endl;
314 }
315
316 TEST(AsyncSSLSocketTest, NpnTestOverlap) {
317   EventBase eventBase;
318   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
319   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
320   int fds[2];
321   getfds(fds);
322   getctx(clientCtx, serverCtx);
323
324   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
325   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
326
327   AsyncSSLSocket::UniquePtr clientSock(
328     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
329   AsyncSSLSocket::UniquePtr serverSock(
330     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
331   NpnClient client(std::move(clientSock));
332   NpnServer server(std::move(serverSock));
333
334   eventBase.loop();
335
336   EXPECT_TRUE(client.nextProtoLength != 0);
337   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
338   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
339                            server.nextProtoLength), 0);
340   string selected((const char*)client.nextProto, client.nextProtoLength);
341   EXPECT_EQ(selected.compare("baz"), 0);
342 }
343
344 TEST(AsyncSSLSocketTest, NpnTestUnset) {
345   // Identical to above test, except that we want unset NPN before
346   // looping.
347   EventBase eventBase;
348   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
349   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
350   int fds[2];
351   getfds(fds);
352   getctx(clientCtx, serverCtx);
353
354   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
355   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
356
357   AsyncSSLSocket::UniquePtr clientSock(
358     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
359   AsyncSSLSocket::UniquePtr serverSock(
360     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
361
362   // unsetting NPN for any of [client, server] is enought to make NPN not
363   // work
364   clientCtx->unsetNextProtocols();
365
366   NpnClient client(std::move(clientSock));
367   NpnServer server(std::move(serverSock));
368
369   eventBase.loop();
370
371   EXPECT_TRUE(client.nextProtoLength == 0);
372   EXPECT_TRUE(server.nextProtoLength == 0);
373   EXPECT_TRUE(client.nextProto == nullptr);
374   EXPECT_TRUE(server.nextProto == nullptr);
375 }
376
377 TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
378   EventBase eventBase;
379   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
380   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
381   int fds[2];
382   getfds(fds);
383   getctx(clientCtx, serverCtx);
384
385   clientCtx->setAdvertisedNextProtocols({"blub"});
386   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
387
388   AsyncSSLSocket::UniquePtr clientSock(
389     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
390   AsyncSSLSocket::UniquePtr serverSock(
391     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
392   NpnClient client(std::move(clientSock));
393   NpnServer server(std::move(serverSock));
394
395   eventBase.loop();
396
397   EXPECT_TRUE(client.nextProtoLength != 0);
398   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
399   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
400                            server.nextProtoLength), 0);
401   string selected((const char*)client.nextProto, client.nextProtoLength);
402   EXPECT_EQ(selected.compare("blub"), 0);
403 }
404
405 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
406   EventBase eventBase;
407   auto clientCtx = std::make_shared<SSLContext>();
408   auto serverCtx = std::make_shared<SSLContext>();
409   int fds[2];
410   getfds(fds);
411   getctx(clientCtx, serverCtx);
412
413   clientCtx->setAdvertisedNextProtocols({"blub"});
414   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
415   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
416
417   AsyncSSLSocket::UniquePtr clientSock(
418     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
419   AsyncSSLSocket::UniquePtr serverSock(
420     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
421   NpnClient client(std::move(clientSock));
422   NpnServer server(std::move(serverSock));
423
424   eventBase.loop();
425
426   EXPECT_TRUE(client.nextProtoLength != 0);
427   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
428   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
429                            server.nextProtoLength), 0);
430   string selected((const char*)client.nextProto, client.nextProtoLength);
431   EXPECT_EQ(selected.compare("ponies"), 0);
432 }
433
434 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
435   EventBase eventBase;
436   auto clientCtx = std::make_shared<SSLContext>();
437   auto serverCtx = std::make_shared<SSLContext>();
438   int fds[2];
439   getfds(fds);
440   getctx(clientCtx, serverCtx);
441
442   clientCtx->setAdvertisedNextProtocols({"blub"});
443   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
444   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
445
446   AsyncSSLSocket::UniquePtr clientSock(
447     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
448   AsyncSSLSocket::UniquePtr serverSock(
449     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
450   NpnClient client(std::move(clientSock));
451   NpnServer server(std::move(serverSock));
452
453   eventBase.loop();
454
455   EXPECT_TRUE(client.nextProtoLength != 0);
456   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
457   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
458                            server.nextProtoLength), 0);
459   string selected((const char*)client.nextProto, client.nextProtoLength);
460   EXPECT_EQ(selected.compare("blub"), 0);
461 }
462
463 TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
464   // Probability that this test will fail is 2^-64, which could be considered
465   // as negligible.
466   const int kTries = 64;
467
468   std::set<string> selectedProtocols;
469   for (int i = 0; i < kTries; ++i) {
470     EventBase eventBase;
471     std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
472     std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
473     int fds[2];
474     getfds(fds);
475     getctx(clientCtx, serverCtx);
476
477     clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
478     serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
479         {1, {"bar"}}});
480
481
482     AsyncSSLSocket::UniquePtr clientSock(
483       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
484     AsyncSSLSocket::UniquePtr serverSock(
485       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
486     NpnClient client(std::move(clientSock));
487     NpnServer server(std::move(serverSock));
488
489     eventBase.loop();
490
491     EXPECT_TRUE(client.nextProtoLength != 0);
492     EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
493     EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
494                              server.nextProtoLength), 0);
495     string selected((const char*)client.nextProto, client.nextProtoLength);
496     selectedProtocols.insert(selected);
497   }
498   EXPECT_EQ(selectedProtocols.size(), 2);
499 }
500
501
502 #ifndef OPENSSL_NO_TLSEXT
503 /**
504  * 1. Client sends TLSEXT_HOSTNAME in client hello.
505  * 2. Server found a match SSL_CTX and use this SSL_CTX to
506  *    continue the SSL handshake.
507  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
508  */
509 TEST(AsyncSSLSocketTest, SNITestMatch) {
510   EventBase eventBase;
511   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
512   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
513   // Use the same SSLContext to continue the handshake after
514   // tlsext_hostname match.
515   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
516   const std::string serverName("xyz.newdev.facebook.com");
517   int fds[2];
518   getfds(fds);
519   getctx(clientCtx, dfServerCtx);
520
521   AsyncSSLSocket::UniquePtr clientSock(
522     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
523   AsyncSSLSocket::UniquePtr serverSock(
524     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
525   SNIClient client(std::move(clientSock));
526   SNIServer server(std::move(serverSock),
527                    dfServerCtx,
528                    hskServerCtx,
529                    serverName);
530
531   eventBase.loop();
532
533   EXPECT_TRUE(client.serverNameMatch);
534   EXPECT_TRUE(server.serverNameMatch);
535 }
536
537 /**
538  * 1. Client sends TLSEXT_HOSTNAME in client hello.
539  * 2. Server cannot find a matching SSL_CTX and continue to use
540  *    the current SSL_CTX to do the handshake.
541  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
542  */
543 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
544   EventBase eventBase;
545   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
546   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
547   // Use the same SSLContext to continue the handshake after
548   // tlsext_hostname match.
549   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
550   const std::string clientRequestingServerName("foo.com");
551   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
552
553   int fds[2];
554   getfds(fds);
555   getctx(clientCtx, dfServerCtx);
556
557   AsyncSSLSocket::UniquePtr clientSock(
558     new AsyncSSLSocket(clientCtx,
559                         &eventBase,
560                         fds[0],
561                         clientRequestingServerName));
562   AsyncSSLSocket::UniquePtr serverSock(
563     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
564   SNIClient client(std::move(clientSock));
565   SNIServer server(std::move(serverSock),
566                    dfServerCtx,
567                    hskServerCtx,
568                    serverExpectedServerName);
569
570   eventBase.loop();
571
572   EXPECT_TRUE(!client.serverNameMatch);
573   EXPECT_TRUE(!server.serverNameMatch);
574 }
575 /**
576  * 1. Client sends TLSEXT_HOSTNAME in client hello.
577  * 2. We then change the serverName.
578  * 3. We expect that we get 'false' as the result for serNameMatch.
579  */
580
581 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
582    EventBase eventBase;
583   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
584   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
585   // Use the same SSLContext to continue the handshake after
586   // tlsext_hostname match.
587   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
588   const std::string serverName("xyz.newdev.facebook.com");
589   int fds[2];
590   getfds(fds);
591   getctx(clientCtx, dfServerCtx);
592
593   AsyncSSLSocket::UniquePtr clientSock(
594     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
595   //Change the server name
596   std::string newName("new.com");
597   clientSock->setServerName(newName);
598   AsyncSSLSocket::UniquePtr serverSock(
599     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
600   SNIClient client(std::move(clientSock));
601   SNIServer server(std::move(serverSock),
602                    dfServerCtx,
603                    hskServerCtx,
604                    serverName);
605
606   eventBase.loop();
607
608   EXPECT_TRUE(!client.serverNameMatch);
609 }
610
611 /**
612  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
613  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
614  */
615 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
616   EventBase eventBase;
617   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
618   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
619   // Use the same SSLContext to continue the handshake after
620   // tlsext_hostname match.
621   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
622   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
623
624   int fds[2];
625   getfds(fds);
626   getctx(clientCtx, dfServerCtx);
627
628   AsyncSSLSocket::UniquePtr clientSock(
629     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
630   AsyncSSLSocket::UniquePtr serverSock(
631     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
632   SNIClient client(std::move(clientSock));
633   SNIServer server(std::move(serverSock),
634                    dfServerCtx,
635                    hskServerCtx,
636                    serverExpectedServerName);
637
638   eventBase.loop();
639
640   EXPECT_TRUE(!client.serverNameMatch);
641   EXPECT_TRUE(!server.serverNameMatch);
642 }
643
644 #endif
645 /**
646  * Test SSL client socket
647  */
648 TEST(AsyncSSLSocketTest, SSLClientTest) {
649   // Start listening on a local port
650   WriteCallbackBase writeCallback;
651   ReadCallback readCallback(&writeCallback);
652   HandshakeCallback handshakeCallback(&readCallback);
653   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
654   TestSSLServer server(&acceptCallback);
655
656   // Set up SSL client
657   EventBase eventBase;
658   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
659                                              1));
660
661   client->connect();
662   EventBaseAborter eba(&eventBase, 3000);
663   eventBase.loop();
664
665   EXPECT_EQ(client->getMiss(), 1);
666   EXPECT_EQ(client->getHit(), 0);
667
668   cerr << "SSLClientTest test completed" << endl;
669 }
670
671
672 /**
673  * Test SSL client socket session re-use
674  */
675 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
676   // Start listening on a local port
677   WriteCallbackBase writeCallback;
678   ReadCallback readCallback(&writeCallback);
679   HandshakeCallback handshakeCallback(&readCallback);
680   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
681   TestSSLServer server(&acceptCallback);
682
683   // Set up SSL client
684   EventBase eventBase;
685   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
686                                              10));
687
688   client->connect();
689   EventBaseAborter eba(&eventBase, 3000);
690   eventBase.loop();
691
692   EXPECT_EQ(client->getMiss(), 1);
693   EXPECT_EQ(client->getHit(), 9);
694
695   cerr << "SSLClientTestReuse test completed" << endl;
696 }
697
698 /**
699  * Test SSL client socket timeout
700  */
701 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
702   // Start listening on a local port
703   EmptyReadCallback readCallback;
704   HandshakeCallback handshakeCallback(&readCallback,
705                                       HandshakeCallback::EXPECT_ERROR);
706   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
707   TestSSLServer server(&acceptCallback);
708
709   // Set up SSL client
710   EventBase eventBase;
711   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
712                                              1, 10));
713   client->connect(true /* write before connect completes */);
714   EventBaseAborter eba(&eventBase, 3000);
715   eventBase.loop();
716
717   usleep(100000);
718   // This is checking that the connectError callback precedes any queued
719   // writeError callbacks.  This matches AsyncSocket's behavior
720   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
721   EXPECT_EQ(client->getErrors(), 1);
722   EXPECT_EQ(client->getMiss(), 0);
723   EXPECT_EQ(client->getHit(), 0);
724
725   cerr << "SSLClientTimeoutTest test completed" << endl;
726 }
727
728
729 /**
730  * Test SSL server async cache
731  */
732 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
733   // Start listening on a local port
734   WriteCallbackBase writeCallback;
735   ReadCallback readCallback(&writeCallback);
736   HandshakeCallback handshakeCallback(&readCallback);
737   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
738   TestSSLAsyncCacheServer server(&acceptCallback);
739
740   // Set up SSL client
741   EventBase eventBase;
742   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
743                                              10, 500));
744
745   client->connect();
746   EventBaseAborter eba(&eventBase, 3000);
747   eventBase.loop();
748
749   EXPECT_EQ(server.getAsyncCallbacks(), 18);
750   EXPECT_EQ(server.getAsyncLookups(), 9);
751   EXPECT_EQ(client->getMiss(), 10);
752   EXPECT_EQ(client->getHit(), 0);
753
754   cerr << "SSLServerAsyncCacheTest test completed" << endl;
755 }
756
757
758 /**
759  * Test SSL server accept timeout with cache path
760  */
761 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
762   // Start listening on a local port
763   WriteCallbackBase writeCallback;
764   ReadCallback readCallback(&writeCallback);
765   EmptyReadCallback clientReadCallback;
766   HandshakeCallback handshakeCallback(&readCallback);
767   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
768   TestSSLAsyncCacheServer server(&acceptCallback);
769
770   // Set up SSL client
771   EventBase eventBase;
772   // only do a TCP connect
773   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
774   sock->connect(nullptr, server.getAddress());
775   clientReadCallback.tcpSocket_ = sock;
776   sock->setReadCB(&clientReadCallback);
777
778   EventBaseAborter eba(&eventBase, 3000);
779   eventBase.loop();
780
781   EXPECT_EQ(readCallback.state, STATE_WAITING);
782
783   cerr << "SSLServerTimeoutTest test completed" << endl;
784 }
785
786 /**
787  * Test SSL server accept timeout with cache path
788  */
789 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
790   // Start listening on a local port
791   WriteCallbackBase writeCallback;
792   ReadCallback readCallback(&writeCallback);
793   HandshakeCallback handshakeCallback(&readCallback);
794   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
795   TestSSLAsyncCacheServer server(&acceptCallback);
796
797   // Set up SSL client
798   EventBase eventBase;
799   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
800                                              2));
801
802   client->connect();
803   EventBaseAborter eba(&eventBase, 3000);
804   eventBase.loop();
805
806   EXPECT_EQ(server.getAsyncCallbacks(), 1);
807   EXPECT_EQ(server.getAsyncLookups(), 1);
808   EXPECT_EQ(client->getErrors(), 1);
809   EXPECT_EQ(client->getMiss(), 1);
810   EXPECT_EQ(client->getHit(), 0);
811
812   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
813 }
814
815 /**
816  * Test SSL server accept timeout with cache path
817  */
818 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
819   // Start listening on a local port
820   WriteCallbackBase writeCallback;
821   ReadCallback readCallback(&writeCallback);
822   HandshakeCallback handshakeCallback(&readCallback,
823                                       HandshakeCallback::EXPECT_ERROR);
824   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
825   TestSSLAsyncCacheServer server(&acceptCallback, 500);
826
827   // Set up SSL client
828   EventBase eventBase;
829   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
830                                              2, 100));
831
832   client->connect();
833   EventBaseAborter eba(&eventBase, 3000);
834   eventBase.loop();
835
836   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
837       handshakeCallback.closeSocket();});
838   // give time for the cache lookup to come back and find it closed
839   usleep(500000);
840
841   EXPECT_EQ(server.getAsyncCallbacks(), 1);
842   EXPECT_EQ(server.getAsyncLookups(), 1);
843   EXPECT_EQ(client->getErrors(), 1);
844   EXPECT_EQ(client->getMiss(), 1);
845   EXPECT_EQ(client->getHit(), 0);
846
847   cerr << "SSLServerCacheCloseTest test completed" << endl;
848 }
849
850 /**
851  * Verify Client Ciphers obtained using SSL MSG Callback.
852  */
853 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
854   EventBase eventBase;
855   auto clientCtx = std::make_shared<SSLContext>();
856   auto serverCtx = std::make_shared<SSLContext>();
857   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
858   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
859   serverCtx->loadPrivateKey(testKey);
860   serverCtx->loadCertificate(testCert);
861   serverCtx->loadTrustedCertificates(testCA);
862   serverCtx->loadClientCAList(testCA);
863
864   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
865   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
866   clientCtx->loadPrivateKey(testKey);
867   clientCtx->loadCertificate(testCert);
868   clientCtx->loadTrustedCertificates(testCA);
869
870   int fds[2];
871   getfds(fds);
872
873   AsyncSSLSocket::UniquePtr clientSock(
874       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
875   AsyncSSLSocket::UniquePtr serverSock(
876       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
877
878   SSLHandshakeClient client(std::move(clientSock), true, true);
879   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
880
881   eventBase.loop();
882
883   EXPECT_EQ(server.clientCiphers_,
884             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
885   EXPECT_TRUE(client.handshakeVerify_);
886   EXPECT_TRUE(client.handshakeSuccess_);
887   EXPECT_TRUE(!client.handshakeError_);
888   EXPECT_TRUE(server.handshakeVerify_);
889   EXPECT_TRUE(server.handshakeSuccess_);
890   EXPECT_TRUE(!server.handshakeError_);
891 }
892
893 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
894   EventBase eventBase;
895   auto ctx = std::make_shared<SSLContext>();
896
897   int fds[2];
898   getfds(fds);
899
900   int bufLen = 42;
901   uint8_t majorVersion = 18;
902   uint8_t minorVersion = 25;
903
904   // Create callback buf
905   auto buf = IOBuf::create(bufLen);
906   buf->append(bufLen);
907   folly::io::RWPrivateCursor cursor(buf.get());
908   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
909   cursor.write<uint16_t>(0);
910   cursor.write<uint8_t>(38);
911   cursor.write<uint8_t>(majorVersion);
912   cursor.write<uint8_t>(minorVersion);
913   cursor.skip(32);
914   cursor.write<uint32_t>(0);
915
916   SSL* ssl = ctx->createSSL();
917   SCOPE_EXIT { SSL_free(ssl); };
918   AsyncSSLSocket::UniquePtr sock(
919       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
920   sock->enableClientHelloParsing();
921
922   // Test client hello parsing in one packet
923   AsyncSSLSocket::clientHelloParsingCallback(
924       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
925   buf.reset();
926
927   auto parsedClientHello = sock->getClientHelloInfo();
928   EXPECT_TRUE(parsedClientHello != nullptr);
929   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
930   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
931 }
932
933 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
934   EventBase eventBase;
935   auto ctx = std::make_shared<SSLContext>();
936
937   int fds[2];
938   getfds(fds);
939
940   int bufLen = 42;
941   uint8_t majorVersion = 18;
942   uint8_t minorVersion = 25;
943
944   // Create callback buf
945   auto buf = IOBuf::create(bufLen);
946   buf->append(bufLen);
947   folly::io::RWPrivateCursor cursor(buf.get());
948   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
949   cursor.write<uint16_t>(0);
950   cursor.write<uint8_t>(38);
951   cursor.write<uint8_t>(majorVersion);
952   cursor.write<uint8_t>(minorVersion);
953   cursor.skip(32);
954   cursor.write<uint32_t>(0);
955
956   SSL* ssl = ctx->createSSL();
957   SCOPE_EXIT { SSL_free(ssl); };
958   AsyncSSLSocket::UniquePtr sock(
959       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
960   sock->enableClientHelloParsing();
961
962   // Test parsing with two packets with first packet size < 3
963   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
964   AsyncSSLSocket::clientHelloParsingCallback(
965       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
966       ssl, sock.get());
967   bufCopy.reset();
968   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
969   AsyncSSLSocket::clientHelloParsingCallback(
970       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
971       ssl, sock.get());
972   bufCopy.reset();
973
974   auto parsedClientHello = sock->getClientHelloInfo();
975   EXPECT_TRUE(parsedClientHello != nullptr);
976   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
977   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
978 }
979
980 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
981   EventBase eventBase;
982   auto ctx = std::make_shared<SSLContext>();
983
984   int fds[2];
985   getfds(fds);
986
987   int bufLen = 42;
988   uint8_t majorVersion = 18;
989   uint8_t minorVersion = 25;
990
991   // Create callback buf
992   auto buf = IOBuf::create(bufLen);
993   buf->append(bufLen);
994   folly::io::RWPrivateCursor cursor(buf.get());
995   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
996   cursor.write<uint16_t>(0);
997   cursor.write<uint8_t>(38);
998   cursor.write<uint8_t>(majorVersion);
999   cursor.write<uint8_t>(minorVersion);
1000   cursor.skip(32);
1001   cursor.write<uint32_t>(0);
1002
1003   SSL* ssl = ctx->createSSL();
1004   SCOPE_EXIT { SSL_free(ssl); };
1005   AsyncSSLSocket::UniquePtr sock(
1006       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1007   sock->enableClientHelloParsing();
1008
1009   // Test parsing with multiple small packets
1010   for (uint64_t i = 0; i < buf->length(); i += 3) {
1011     auto bufCopy = folly::IOBuf::copyBuffer(
1012         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1013     AsyncSSLSocket::clientHelloParsingCallback(
1014         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1015         ssl, sock.get());
1016     bufCopy.reset();
1017   }
1018
1019   auto parsedClientHello = sock->getClientHelloInfo();
1020   EXPECT_TRUE(parsedClientHello != nullptr);
1021   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1022   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1023 }
1024
1025 /**
1026  * Verify sucessful behavior of SSL certificate validation.
1027  */
1028 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1029   EventBase eventBase;
1030   auto clientCtx = std::make_shared<SSLContext>();
1031   auto dfServerCtx = std::make_shared<SSLContext>();
1032
1033   int fds[2];
1034   getfds(fds);
1035   getctx(clientCtx, dfServerCtx);
1036
1037   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1038   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1039
1040   AsyncSSLSocket::UniquePtr clientSock(
1041     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1042   AsyncSSLSocket::UniquePtr serverSock(
1043     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1044
1045   SSLHandshakeClient client(std::move(clientSock), true, true);
1046   clientCtx->loadTrustedCertificates(testCA);
1047
1048   SSLHandshakeServer server(std::move(serverSock), true, true);
1049
1050   eventBase.loop();
1051
1052   EXPECT_TRUE(client.handshakeVerify_);
1053   EXPECT_TRUE(client.handshakeSuccess_);
1054   EXPECT_TRUE(!client.handshakeError_);
1055   EXPECT_TRUE(!server.handshakeVerify_);
1056   EXPECT_TRUE(server.handshakeSuccess_);
1057   EXPECT_TRUE(!server.handshakeError_);
1058 }
1059
1060 /**
1061  * Verify that the client's verification callback is able to fail SSL
1062  * connection establishment.
1063  */
1064 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1065   EventBase eventBase;
1066   auto clientCtx = std::make_shared<SSLContext>();
1067   auto dfServerCtx = std::make_shared<SSLContext>();
1068
1069   int fds[2];
1070   getfds(fds);
1071   getctx(clientCtx, dfServerCtx);
1072
1073   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1074   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1075
1076   AsyncSSLSocket::UniquePtr clientSock(
1077     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1078   AsyncSSLSocket::UniquePtr serverSock(
1079     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1080
1081   SSLHandshakeClient client(std::move(clientSock), true, false);
1082   clientCtx->loadTrustedCertificates(testCA);
1083
1084   SSLHandshakeServer server(std::move(serverSock), true, true);
1085
1086   eventBase.loop();
1087
1088   EXPECT_TRUE(client.handshakeVerify_);
1089   EXPECT_TRUE(!client.handshakeSuccess_);
1090   EXPECT_TRUE(client.handshakeError_);
1091   EXPECT_TRUE(!server.handshakeVerify_);
1092   EXPECT_TRUE(!server.handshakeSuccess_);
1093   EXPECT_TRUE(server.handshakeError_);
1094 }
1095
1096 /**
1097  * Verify that the options in SSLContext can be overridden in
1098  * sslConnect/Accept.i.e specifying that no validation should be performed
1099  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1100  * the validation callback.
1101  */
1102 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1103   EventBase eventBase;
1104   auto clientCtx = std::make_shared<SSLContext>();
1105   auto dfServerCtx = std::make_shared<SSLContext>();
1106
1107   int fds[2];
1108   getfds(fds);
1109   getctx(clientCtx, dfServerCtx);
1110
1111   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1112   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1113
1114   AsyncSSLSocket::UniquePtr clientSock(
1115     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1116   AsyncSSLSocket::UniquePtr serverSock(
1117     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1118
1119   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1120   clientCtx->loadTrustedCertificates(testCA);
1121
1122   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1123
1124   eventBase.loop();
1125
1126   EXPECT_TRUE(!client.handshakeVerify_);
1127   EXPECT_TRUE(client.handshakeSuccess_);
1128   EXPECT_TRUE(!client.handshakeError_);
1129   EXPECT_TRUE(!server.handshakeVerify_);
1130   EXPECT_TRUE(server.handshakeSuccess_);
1131   EXPECT_TRUE(!server.handshakeError_);
1132 }
1133
1134 /**
1135  * Verify that the options in SSLContext can be overridden in
1136  * sslConnect/Accept. Enable verification even if context says otherwise.
1137  * Test requireClientCert with client cert
1138  */
1139 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1140   EventBase eventBase;
1141   auto clientCtx = std::make_shared<SSLContext>();
1142   auto serverCtx = std::make_shared<SSLContext>();
1143   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1144   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1145   serverCtx->loadPrivateKey(testKey);
1146   serverCtx->loadCertificate(testCert);
1147   serverCtx->loadTrustedCertificates(testCA);
1148   serverCtx->loadClientCAList(testCA);
1149
1150   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1151   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1152   clientCtx->loadPrivateKey(testKey);
1153   clientCtx->loadCertificate(testCert);
1154   clientCtx->loadTrustedCertificates(testCA);
1155
1156   int fds[2];
1157   getfds(fds);
1158
1159   AsyncSSLSocket::UniquePtr clientSock(
1160       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1161   AsyncSSLSocket::UniquePtr serverSock(
1162       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1163
1164   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1165   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1166
1167   eventBase.loop();
1168
1169   EXPECT_TRUE(client.handshakeVerify_);
1170   EXPECT_TRUE(client.handshakeSuccess_);
1171   EXPECT_FALSE(client.handshakeError_);
1172   EXPECT_TRUE(server.handshakeVerify_);
1173   EXPECT_TRUE(server.handshakeSuccess_);
1174   EXPECT_FALSE(server.handshakeError_);
1175 }
1176
1177 /**
1178  * Verify that the client's verification callback is able to override
1179  * the preverification failure and allow a successful connection.
1180  */
1181 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1182   EventBase eventBase;
1183   auto clientCtx = std::make_shared<SSLContext>();
1184   auto dfServerCtx = std::make_shared<SSLContext>();
1185
1186   int fds[2];
1187   getfds(fds);
1188   getctx(clientCtx, dfServerCtx);
1189
1190   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1191   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1192
1193   AsyncSSLSocket::UniquePtr clientSock(
1194     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1195   AsyncSSLSocket::UniquePtr serverSock(
1196     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1197
1198   SSLHandshakeClient client(std::move(clientSock), false, true);
1199   SSLHandshakeServer server(std::move(serverSock), true, true);
1200
1201   eventBase.loop();
1202
1203   EXPECT_TRUE(client.handshakeVerify_);
1204   EXPECT_TRUE(client.handshakeSuccess_);
1205   EXPECT_TRUE(!client.handshakeError_);
1206   EXPECT_TRUE(!server.handshakeVerify_);
1207   EXPECT_TRUE(server.handshakeSuccess_);
1208   EXPECT_TRUE(!server.handshakeError_);
1209 }
1210
1211 /**
1212  * Verify that specifying that no validation should be performed allows an
1213  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1214  * callback.
1215  */
1216 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1217   EventBase eventBase;
1218   auto clientCtx = std::make_shared<SSLContext>();
1219   auto dfServerCtx = std::make_shared<SSLContext>();
1220
1221   int fds[2];
1222   getfds(fds);
1223   getctx(clientCtx, dfServerCtx);
1224
1225   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1226   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1227
1228   AsyncSSLSocket::UniquePtr clientSock(
1229     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1230   AsyncSSLSocket::UniquePtr serverSock(
1231     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1232
1233   SSLHandshakeClient client(std::move(clientSock), false, false);
1234   SSLHandshakeServer server(std::move(serverSock), false, false);
1235
1236   eventBase.loop();
1237
1238   EXPECT_TRUE(!client.handshakeVerify_);
1239   EXPECT_TRUE(client.handshakeSuccess_);
1240   EXPECT_TRUE(!client.handshakeError_);
1241   EXPECT_TRUE(!server.handshakeVerify_);
1242   EXPECT_TRUE(server.handshakeSuccess_);
1243   EXPECT_TRUE(!server.handshakeError_);
1244 }
1245
1246 /**
1247  * Test requireClientCert with client cert
1248  */
1249 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1250   EventBase eventBase;
1251   auto clientCtx = std::make_shared<SSLContext>();
1252   auto serverCtx = std::make_shared<SSLContext>();
1253   serverCtx->setVerificationOption(
1254       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1255   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1256   serverCtx->loadPrivateKey(testKey);
1257   serverCtx->loadCertificate(testCert);
1258   serverCtx->loadTrustedCertificates(testCA);
1259   serverCtx->loadClientCAList(testCA);
1260
1261   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1262   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1263   clientCtx->loadPrivateKey(testKey);
1264   clientCtx->loadCertificate(testCert);
1265   clientCtx->loadTrustedCertificates(testCA);
1266
1267   int fds[2];
1268   getfds(fds);
1269
1270   AsyncSSLSocket::UniquePtr clientSock(
1271       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1272   AsyncSSLSocket::UniquePtr serverSock(
1273       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1274
1275   SSLHandshakeClient client(std::move(clientSock), true, true);
1276   SSLHandshakeServer server(std::move(serverSock), true, true);
1277
1278   eventBase.loop();
1279
1280   EXPECT_TRUE(client.handshakeVerify_);
1281   EXPECT_TRUE(client.handshakeSuccess_);
1282   EXPECT_FALSE(client.handshakeError_);
1283   EXPECT_TRUE(server.handshakeVerify_);
1284   EXPECT_TRUE(server.handshakeSuccess_);
1285   EXPECT_FALSE(server.handshakeError_);
1286 }
1287
1288
1289 /**
1290  * Test requireClientCert with no client cert
1291  */
1292 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1293   EventBase eventBase;
1294   auto clientCtx = std::make_shared<SSLContext>();
1295   auto serverCtx = std::make_shared<SSLContext>();
1296   serverCtx->setVerificationOption(
1297       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1298   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1299   serverCtx->loadPrivateKey(testKey);
1300   serverCtx->loadCertificate(testCert);
1301   serverCtx->loadTrustedCertificates(testCA);
1302   serverCtx->loadClientCAList(testCA);
1303   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1304   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1305
1306   int fds[2];
1307   getfds(fds);
1308
1309   AsyncSSLSocket::UniquePtr clientSock(
1310       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1311   AsyncSSLSocket::UniquePtr serverSock(
1312       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1313
1314   SSLHandshakeClient client(std::move(clientSock), false, false);
1315   SSLHandshakeServer server(std::move(serverSock), false, false);
1316
1317   eventBase.loop();
1318
1319   EXPECT_FALSE(server.handshakeVerify_);
1320   EXPECT_FALSE(server.handshakeSuccess_);
1321   EXPECT_TRUE(server.handshakeError_);
1322 }
1323
1324 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1325   EventBase eb;
1326
1327   // Set up SSL context.
1328   auto sslContext = std::make_shared<SSLContext>();
1329   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1330
1331   // create SSL socket
1332   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1333
1334   EXPECT_EQ(1500, socket->getMinWriteSize());
1335
1336   socket->setMinWriteSize(0);
1337   EXPECT_EQ(0, socket->getMinWriteSize());
1338   socket->setMinWriteSize(50000);
1339   EXPECT_EQ(50000, socket->getMinWriteSize());
1340 }
1341
1342 class ReadCallbackTerminator : public ReadCallback {
1343  public:
1344   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1345       : ReadCallback(wcb)
1346       , base_(base) {}
1347
1348   // Do not write data back, terminate the loop.
1349   void readDataAvailable(size_t len) noexcept override {
1350     std::cerr << "readDataAvailable, len " << len << std::endl;
1351
1352     currentBuffer.length = len;
1353
1354     buffers.push_back(currentBuffer);
1355     currentBuffer.reset();
1356     state = STATE_SUCCEEDED;
1357
1358     socket_->setReadCB(nullptr);
1359     base_->terminateLoopSoon();
1360   }
1361  private:
1362   EventBase* base_;
1363 };
1364
1365
1366 /**
1367  * Test a full unencrypted codepath
1368  */
1369 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1370   EventBase base;
1371
1372   auto clientCtx = std::make_shared<folly::SSLContext>();
1373   auto serverCtx = std::make_shared<folly::SSLContext>();
1374   int fds[2];
1375   getfds(fds);
1376   getctx(clientCtx, serverCtx);
1377   auto client = AsyncSSLSocket::newSocket(
1378                   clientCtx, &base, fds[0], false, true);
1379   auto server = AsyncSSLSocket::newSocket(
1380                   serverCtx, &base, fds[1], true, true);
1381
1382   ReadCallbackTerminator readCallback(&base, nullptr);
1383   server->setReadCB(&readCallback);
1384   readCallback.setSocket(server);
1385
1386   uint8_t buf[128];
1387   memset(buf, 'a', sizeof(buf));
1388   client->write(nullptr, buf, sizeof(buf));
1389
1390   // Check that bytes are unencrypted
1391   char c;
1392   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1393   EXPECT_EQ('a', c);
1394
1395   EventBaseAborter eba(&base, 3000);
1396   base.loop();
1397
1398   EXPECT_EQ(1, readCallback.buffers.size());
1399   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1400
1401   server->setReadCB(&readCallback);
1402
1403   // Unencrypted
1404   server->sslAccept(nullptr);
1405   client->sslConn(nullptr);
1406
1407   // Do NOT wait for handshake, writing should be queued and happen after
1408
1409   client->write(nullptr, buf, sizeof(buf));
1410
1411   // Check that bytes are *not* unencrypted
1412   char c2;
1413   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1414   EXPECT_NE('a', c2);
1415
1416
1417   base.loop();
1418
1419   EXPECT_EQ(2, readCallback.buffers.size());
1420   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1421 }
1422
1423 } // namespace
1424
1425 ///////////////////////////////////////////////////////////////////////////
1426 // init_unit_test_suite
1427 ///////////////////////////////////////////////////////////////////////////
1428 namespace {
1429 struct Initializer {
1430   Initializer() {
1431     signal(SIGPIPE, SIG_IGN);
1432   }
1433 };
1434 Initializer initializer;
1435 } // anonymous