fix compiler warnings from gcc-4.9 + -Wunused-variable
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <gtest/gtest.h>
28 #include <iostream>
29 #include <list>
30 #include <set>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <poll.h>
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <netinet/tcp.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
59 ctx_(new folly::SSLContext),
60     acb_(acb),
61   socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79   (void)ret;
80
81   std::cerr << "Accepting connections on " << address_ << std::endl;
82 }
83
84 void getfds(int fds[2]) {
85   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
86     FAIL() << "failed to create socketpair: " << strerror(errno);
87   }
88   for (int idx = 0; idx < 2; ++idx) {
89     int flags = fcntl(fds[idx], F_GETFL, 0);
90     if (flags == -1) {
91       FAIL() << "failed to get flags for socket " << idx << ": "
92              << strerror(errno);
93     }
94     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
95       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
96              << strerror(errno);
97     }
98   }
99 }
100
101 void getctx(
102   std::shared_ptr<folly::SSLContext> clientCtx,
103   std::shared_ptr<folly::SSLContext> serverCtx) {
104   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
105
106   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107   serverCtx->loadCertificate(
108       testCert);
109   serverCtx->loadPrivateKey(
110       testKey);
111 }
112
113 void sslsocketpair(
114   EventBase* eventBase,
115   AsyncSSLSocket::UniquePtr* clientSock,
116   AsyncSSLSocket::UniquePtr* serverSock) {
117   auto clientCtx = std::make_shared<folly::SSLContext>();
118   auto serverCtx = std::make_shared<folly::SSLContext>();
119   int fds[2];
120   getfds(fds);
121   getctx(clientCtx, serverCtx);
122   clientSock->reset(new AsyncSSLSocket(
123                       clientCtx, eventBase, fds[0], false));
124   serverSock->reset(new AsyncSSLSocket(
125                       serverCtx, eventBase, fds[1], true));
126
127   // (*clientSock)->setSendTimeout(100);
128   // (*serverSock)->setSendTimeout(100);
129 }
130
131 // client protocol filters
132 bool clientProtoFilterPickPony(unsigned char** client,
133   unsigned int* client_len, const unsigned char*, unsigned int ) {
134   //the protocol string in length prefixed byte string. the
135   //length byte is not included in the length
136   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
137   *client = p;
138   *client_len = 7;
139   return true;
140 }
141
142 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
143   const unsigned char*, unsigned int) {
144   return false;
145 }
146
147 /**
148  * Test connecting to, writing to, reading from, and closing the
149  * connection to the SSL server.
150  */
151 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
152   // Start listening on a local port
153   WriteCallbackBase writeCallback;
154   ReadCallback readCallback(&writeCallback);
155   HandshakeCallback handshakeCallback(&readCallback);
156   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
157   TestSSLServer server(&acceptCallback);
158
159   // Set up SSL context.
160   std::shared_ptr<SSLContext> sslContext(new SSLContext());
161   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
162   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
163   //sslContext->authenticate(true, false);
164
165   // connect
166   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
167                                                  sslContext);
168   socket->open();
169
170   // write()
171   uint8_t buf[128];
172   memset(buf, 'a', sizeof(buf));
173   socket->write(buf, sizeof(buf));
174
175   // read()
176   uint8_t readbuf[128];
177   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
178   EXPECT_EQ(bytesRead, 128);
179   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
180
181   // close()
182   socket->close();
183
184   cerr << "ConnectWriteReadClose test completed" << endl;
185 }
186
187 /**
188  * Negative test for handshakeError().
189  */
190 TEST(AsyncSSLSocketTest, HandshakeError) {
191   // Start listening on a local port
192   WriteCallbackBase writeCallback;
193   ReadCallback readCallback(&writeCallback);
194   HandshakeCallback handshakeCallback(&readCallback);
195   HandshakeErrorCallback acceptCallback(&handshakeCallback);
196   TestSSLServer server(&acceptCallback);
197
198   // Set up SSL context.
199   std::shared_ptr<SSLContext> sslContext(new SSLContext());
200   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
201
202   // connect
203   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
204                                                  sslContext);
205   // read()
206   bool ex = false;
207   try {
208     socket->open();
209
210     uint8_t readbuf[128];
211     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
212     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
213   } catch (AsyncSocketException &e) {
214     ex = true;
215   }
216   EXPECT_TRUE(ex);
217
218   // close()
219   socket->close();
220   cerr << "HandshakeError test completed" << endl;
221 }
222
223 /**
224  * Negative test for readError().
225  */
226 TEST(AsyncSSLSocketTest, ReadError) {
227   // Start listening on a local port
228   WriteCallbackBase writeCallback;
229   ReadErrorCallback readCallback(&writeCallback);
230   HandshakeCallback handshakeCallback(&readCallback);
231   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
232   TestSSLServer server(&acceptCallback);
233
234   // Set up SSL context.
235   std::shared_ptr<SSLContext> sslContext(new SSLContext());
236   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
237
238   // connect
239   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
240                                                  sslContext);
241   socket->open();
242
243   // write something to trigger ssl handshake
244   uint8_t buf[128];
245   memset(buf, 'a', sizeof(buf));
246   socket->write(buf, sizeof(buf));
247
248   socket->close();
249   cerr << "ReadError test completed" << endl;
250 }
251
252 /**
253  * Negative test for writeError().
254  */
255 TEST(AsyncSSLSocketTest, WriteError) {
256   // Start listening on a local port
257   WriteCallbackBase writeCallback;
258   WriteErrorCallback readCallback(&writeCallback);
259   HandshakeCallback handshakeCallback(&readCallback);
260   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
261   TestSSLServer server(&acceptCallback);
262
263   // Set up SSL context.
264   std::shared_ptr<SSLContext> sslContext(new SSLContext());
265   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
266
267   // connect
268   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
269                                                  sslContext);
270   socket->open();
271
272   // write something to trigger ssl handshake
273   uint8_t buf[128];
274   memset(buf, 'a', sizeof(buf));
275   socket->write(buf, sizeof(buf));
276
277   socket->close();
278   cerr << "WriteError test completed" << endl;
279 }
280
281 /**
282  * Test a socket with TCP_NODELAY unset.
283  */
284 TEST(AsyncSSLSocketTest, SocketWithDelay) {
285   // Start listening on a local port
286   WriteCallbackBase writeCallback;
287   ReadCallback readCallback(&writeCallback);
288   HandshakeCallback handshakeCallback(&readCallback);
289   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
290   TestSSLServer server(&acceptCallback);
291
292   // Set up SSL context.
293   std::shared_ptr<SSLContext> sslContext(new SSLContext());
294   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
295
296   // connect
297   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
298                                                  sslContext);
299   socket->open();
300
301   // write()
302   uint8_t buf[128];
303   memset(buf, 'a', sizeof(buf));
304   socket->write(buf, sizeof(buf));
305
306   // read()
307   uint8_t readbuf[128];
308   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
309   EXPECT_EQ(bytesRead, 128);
310   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
311
312   // close()
313   socket->close();
314
315   cerr << "SocketWithDelay test completed" << endl;
316 }
317
318 TEST(AsyncSSLSocketTest, NpnTestOverlap) {
319   EventBase eventBase;
320   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
321   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
322   int fds[2];
323   getfds(fds);
324   getctx(clientCtx, serverCtx);
325
326   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
327   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
328
329   AsyncSSLSocket::UniquePtr clientSock(
330     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
331   AsyncSSLSocket::UniquePtr serverSock(
332     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
333   NpnClient client(std::move(clientSock));
334   NpnServer server(std::move(serverSock));
335
336   eventBase.loop();
337
338   EXPECT_TRUE(client.nextProtoLength != 0);
339   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
340   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
341                            server.nextProtoLength), 0);
342   string selected((const char*)client.nextProto, client.nextProtoLength);
343   EXPECT_EQ(selected.compare("baz"), 0);
344 }
345
346 TEST(AsyncSSLSocketTest, NpnTestUnset) {
347   // Identical to above test, except that we want unset NPN before
348   // looping.
349   EventBase eventBase;
350   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
351   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
352   int fds[2];
353   getfds(fds);
354   getctx(clientCtx, serverCtx);
355
356   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
357   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
358
359   AsyncSSLSocket::UniquePtr clientSock(
360     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
361   AsyncSSLSocket::UniquePtr serverSock(
362     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
363
364   // unsetting NPN for any of [client, server] is enought to make NPN not
365   // work
366   clientCtx->unsetNextProtocols();
367
368   NpnClient client(std::move(clientSock));
369   NpnServer server(std::move(serverSock));
370
371   eventBase.loop();
372
373   EXPECT_TRUE(client.nextProtoLength == 0);
374   EXPECT_TRUE(server.nextProtoLength == 0);
375   EXPECT_TRUE(client.nextProto == nullptr);
376   EXPECT_TRUE(server.nextProto == nullptr);
377 }
378
379 TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
380   EventBase eventBase;
381   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
382   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
383   int fds[2];
384   getfds(fds);
385   getctx(clientCtx, serverCtx);
386
387   clientCtx->setAdvertisedNextProtocols({"blub"});
388   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
389
390   AsyncSSLSocket::UniquePtr clientSock(
391     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
392   AsyncSSLSocket::UniquePtr serverSock(
393     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
394   NpnClient client(std::move(clientSock));
395   NpnServer server(std::move(serverSock));
396
397   eventBase.loop();
398
399   EXPECT_TRUE(client.nextProtoLength != 0);
400   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
401   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
402                            server.nextProtoLength), 0);
403   string selected((const char*)client.nextProto, client.nextProtoLength);
404   EXPECT_EQ(selected.compare("blub"), 0);
405 }
406
407 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
408   EventBase eventBase;
409   auto clientCtx = std::make_shared<SSLContext>();
410   auto serverCtx = std::make_shared<SSLContext>();
411   int fds[2];
412   getfds(fds);
413   getctx(clientCtx, serverCtx);
414
415   clientCtx->setAdvertisedNextProtocols({"blub"});
416   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
417   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
418
419   AsyncSSLSocket::UniquePtr clientSock(
420     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
421   AsyncSSLSocket::UniquePtr serverSock(
422     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
423   NpnClient client(std::move(clientSock));
424   NpnServer server(std::move(serverSock));
425
426   eventBase.loop();
427
428   EXPECT_TRUE(client.nextProtoLength != 0);
429   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
430   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
431                            server.nextProtoLength), 0);
432   string selected((const char*)client.nextProto, client.nextProtoLength);
433   EXPECT_EQ(selected.compare("ponies"), 0);
434 }
435
436 TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
437   EventBase eventBase;
438   auto clientCtx = std::make_shared<SSLContext>();
439   auto serverCtx = std::make_shared<SSLContext>();
440   int fds[2];
441   getfds(fds);
442   getctx(clientCtx, serverCtx);
443
444   clientCtx->setAdvertisedNextProtocols({"blub"});
445   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
446   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
447
448   AsyncSSLSocket::UniquePtr clientSock(
449     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
450   AsyncSSLSocket::UniquePtr serverSock(
451     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
452   NpnClient client(std::move(clientSock));
453   NpnServer server(std::move(serverSock));
454
455   eventBase.loop();
456
457   EXPECT_TRUE(client.nextProtoLength != 0);
458   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
459   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
460                            server.nextProtoLength), 0);
461   string selected((const char*)client.nextProto, client.nextProtoLength);
462   EXPECT_EQ(selected.compare("blub"), 0);
463 }
464
465 TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
466   // Probability that this test will fail is 2^-64, which could be considered
467   // as negligible.
468   const int kTries = 64;
469
470   std::set<string> selectedProtocols;
471   for (int i = 0; i < kTries; ++i) {
472     EventBase eventBase;
473     std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
474     std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
475     int fds[2];
476     getfds(fds);
477     getctx(clientCtx, serverCtx);
478
479     clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
480     serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
481         {1, {"bar"}}});
482
483
484     AsyncSSLSocket::UniquePtr clientSock(
485       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
486     AsyncSSLSocket::UniquePtr serverSock(
487       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
488     NpnClient client(std::move(clientSock));
489     NpnServer server(std::move(serverSock));
490
491     eventBase.loop();
492
493     EXPECT_TRUE(client.nextProtoLength != 0);
494     EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
495     EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
496                              server.nextProtoLength), 0);
497     string selected((const char*)client.nextProto, client.nextProtoLength);
498     selectedProtocols.insert(selected);
499   }
500   EXPECT_EQ(selectedProtocols.size(), 2);
501 }
502
503
504 #ifndef OPENSSL_NO_TLSEXT
505 /**
506  * 1. Client sends TLSEXT_HOSTNAME in client hello.
507  * 2. Server found a match SSL_CTX and use this SSL_CTX to
508  *    continue the SSL handshake.
509  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
510  */
511 TEST(AsyncSSLSocketTest, SNITestMatch) {
512   EventBase eventBase;
513   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
514   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
515   // Use the same SSLContext to continue the handshake after
516   // tlsext_hostname match.
517   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
518   const std::string serverName("xyz.newdev.facebook.com");
519   int fds[2];
520   getfds(fds);
521   getctx(clientCtx, dfServerCtx);
522
523   AsyncSSLSocket::UniquePtr clientSock(
524     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
525   AsyncSSLSocket::UniquePtr serverSock(
526     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
527   SNIClient client(std::move(clientSock));
528   SNIServer server(std::move(serverSock),
529                    dfServerCtx,
530                    hskServerCtx,
531                    serverName);
532
533   eventBase.loop();
534
535   EXPECT_TRUE(client.serverNameMatch);
536   EXPECT_TRUE(server.serverNameMatch);
537 }
538
539 /**
540  * 1. Client sends TLSEXT_HOSTNAME in client hello.
541  * 2. Server cannot find a matching SSL_CTX and continue to use
542  *    the current SSL_CTX to do the handshake.
543  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
544  */
545 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
546   EventBase eventBase;
547   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
548   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
549   // Use the same SSLContext to continue the handshake after
550   // tlsext_hostname match.
551   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
552   const std::string clientRequestingServerName("foo.com");
553   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
554
555   int fds[2];
556   getfds(fds);
557   getctx(clientCtx, dfServerCtx);
558
559   AsyncSSLSocket::UniquePtr clientSock(
560     new AsyncSSLSocket(clientCtx,
561                         &eventBase,
562                         fds[0],
563                         clientRequestingServerName));
564   AsyncSSLSocket::UniquePtr serverSock(
565     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
566   SNIClient client(std::move(clientSock));
567   SNIServer server(std::move(serverSock),
568                    dfServerCtx,
569                    hskServerCtx,
570                    serverExpectedServerName);
571
572   eventBase.loop();
573
574   EXPECT_TRUE(!client.serverNameMatch);
575   EXPECT_TRUE(!server.serverNameMatch);
576 }
577 /**
578  * 1. Client sends TLSEXT_HOSTNAME in client hello.
579  * 2. We then change the serverName.
580  * 3. We expect that we get 'false' as the result for serNameMatch.
581  */
582
583 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
584    EventBase eventBase;
585   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
586   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
587   // Use the same SSLContext to continue the handshake after
588   // tlsext_hostname match.
589   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
590   const std::string serverName("xyz.newdev.facebook.com");
591   int fds[2];
592   getfds(fds);
593   getctx(clientCtx, dfServerCtx);
594
595   AsyncSSLSocket::UniquePtr clientSock(
596     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
597   //Change the server name
598   std::string newName("new.com");
599   clientSock->setServerName(newName);
600   AsyncSSLSocket::UniquePtr serverSock(
601     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
602   SNIClient client(std::move(clientSock));
603   SNIServer server(std::move(serverSock),
604                    dfServerCtx,
605                    hskServerCtx,
606                    serverName);
607
608   eventBase.loop();
609
610   EXPECT_TRUE(!client.serverNameMatch);
611 }
612
613 /**
614  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
615  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
616  */
617 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
618   EventBase eventBase;
619   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
620   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
621   // Use the same SSLContext to continue the handshake after
622   // tlsext_hostname match.
623   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
624   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
625
626   int fds[2];
627   getfds(fds);
628   getctx(clientCtx, dfServerCtx);
629
630   AsyncSSLSocket::UniquePtr clientSock(
631     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
632   AsyncSSLSocket::UniquePtr serverSock(
633     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
634   SNIClient client(std::move(clientSock));
635   SNIServer server(std::move(serverSock),
636                    dfServerCtx,
637                    hskServerCtx,
638                    serverExpectedServerName);
639
640   eventBase.loop();
641
642   EXPECT_TRUE(!client.serverNameMatch);
643   EXPECT_TRUE(!server.serverNameMatch);
644 }
645
646 #endif
647 /**
648  * Test SSL client socket
649  */
650 TEST(AsyncSSLSocketTest, SSLClientTest) {
651   // Start listening on a local port
652   WriteCallbackBase writeCallback;
653   ReadCallback readCallback(&writeCallback);
654   HandshakeCallback handshakeCallback(&readCallback);
655   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
656   TestSSLServer server(&acceptCallback);
657
658   // Set up SSL client
659   EventBase eventBase;
660   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
661                                              1));
662
663   client->connect();
664   EventBaseAborter eba(&eventBase, 3000);
665   eventBase.loop();
666
667   EXPECT_EQ(client->getMiss(), 1);
668   EXPECT_EQ(client->getHit(), 0);
669
670   cerr << "SSLClientTest test completed" << endl;
671 }
672
673
674 /**
675  * Test SSL client socket session re-use
676  */
677 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
678   // Start listening on a local port
679   WriteCallbackBase writeCallback;
680   ReadCallback readCallback(&writeCallback);
681   HandshakeCallback handshakeCallback(&readCallback);
682   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
683   TestSSLServer server(&acceptCallback);
684
685   // Set up SSL client
686   EventBase eventBase;
687   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
688                                              10));
689
690   client->connect();
691   EventBaseAborter eba(&eventBase, 3000);
692   eventBase.loop();
693
694   EXPECT_EQ(client->getMiss(), 1);
695   EXPECT_EQ(client->getHit(), 9);
696
697   cerr << "SSLClientTestReuse test completed" << endl;
698 }
699
700 /**
701  * Test SSL client socket timeout
702  */
703 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
704   // Start listening on a local port
705   EmptyReadCallback readCallback;
706   HandshakeCallback handshakeCallback(&readCallback,
707                                       HandshakeCallback::EXPECT_ERROR);
708   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
709   TestSSLServer server(&acceptCallback);
710
711   // Set up SSL client
712   EventBase eventBase;
713   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
714                                              1, 10));
715   client->connect(true /* write before connect completes */);
716   EventBaseAborter eba(&eventBase, 3000);
717   eventBase.loop();
718
719   usleep(100000);
720   // This is checking that the connectError callback precedes any queued
721   // writeError callbacks.  This matches AsyncSocket's behavior
722   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
723   EXPECT_EQ(client->getErrors(), 1);
724   EXPECT_EQ(client->getMiss(), 0);
725   EXPECT_EQ(client->getHit(), 0);
726
727   cerr << "SSLClientTimeoutTest test completed" << endl;
728 }
729
730
731 /**
732  * Test SSL server async cache
733  */
734 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
735   // Start listening on a local port
736   WriteCallbackBase writeCallback;
737   ReadCallback readCallback(&writeCallback);
738   HandshakeCallback handshakeCallback(&readCallback);
739   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
740   TestSSLAsyncCacheServer server(&acceptCallback);
741
742   // Set up SSL client
743   EventBase eventBase;
744   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
745                                              10, 500));
746
747   client->connect();
748   EventBaseAborter eba(&eventBase, 3000);
749   eventBase.loop();
750
751   EXPECT_EQ(server.getAsyncCallbacks(), 18);
752   EXPECT_EQ(server.getAsyncLookups(), 9);
753   EXPECT_EQ(client->getMiss(), 10);
754   EXPECT_EQ(client->getHit(), 0);
755
756   cerr << "SSLServerAsyncCacheTest test completed" << endl;
757 }
758
759
760 /**
761  * Test SSL server accept timeout with cache path
762  */
763 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
764   // Start listening on a local port
765   WriteCallbackBase writeCallback;
766   ReadCallback readCallback(&writeCallback);
767   EmptyReadCallback clientReadCallback;
768   HandshakeCallback handshakeCallback(&readCallback);
769   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
770   TestSSLAsyncCacheServer server(&acceptCallback);
771
772   // Set up SSL client
773   EventBase eventBase;
774   // only do a TCP connect
775   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
776   sock->connect(nullptr, server.getAddress());
777   clientReadCallback.tcpSocket_ = sock;
778   sock->setReadCB(&clientReadCallback);
779
780   EventBaseAborter eba(&eventBase, 3000);
781   eventBase.loop();
782
783   EXPECT_EQ(readCallback.state, STATE_WAITING);
784
785   cerr << "SSLServerTimeoutTest test completed" << endl;
786 }
787
788 /**
789  * Test SSL server accept timeout with cache path
790  */
791 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
792   // Start listening on a local port
793   WriteCallbackBase writeCallback;
794   ReadCallback readCallback(&writeCallback);
795   HandshakeCallback handshakeCallback(&readCallback);
796   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
797   TestSSLAsyncCacheServer server(&acceptCallback);
798
799   // Set up SSL client
800   EventBase eventBase;
801   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
802                                              2));
803
804   client->connect();
805   EventBaseAborter eba(&eventBase, 3000);
806   eventBase.loop();
807
808   EXPECT_EQ(server.getAsyncCallbacks(), 1);
809   EXPECT_EQ(server.getAsyncLookups(), 1);
810   EXPECT_EQ(client->getErrors(), 1);
811   EXPECT_EQ(client->getMiss(), 1);
812   EXPECT_EQ(client->getHit(), 0);
813
814   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
815 }
816
817 /**
818  * Test SSL server accept timeout with cache path
819  */
820 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
821   // Start listening on a local port
822   WriteCallbackBase writeCallback;
823   ReadCallback readCallback(&writeCallback);
824   HandshakeCallback handshakeCallback(&readCallback,
825                                       HandshakeCallback::EXPECT_ERROR);
826   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
827   TestSSLAsyncCacheServer server(&acceptCallback, 500);
828
829   // Set up SSL client
830   EventBase eventBase;
831   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
832                                              2, 100));
833
834   client->connect();
835   EventBaseAborter eba(&eventBase, 3000);
836   eventBase.loop();
837
838   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
839       handshakeCallback.closeSocket();});
840   // give time for the cache lookup to come back and find it closed
841   usleep(500000);
842
843   EXPECT_EQ(server.getAsyncCallbacks(), 1);
844   EXPECT_EQ(server.getAsyncLookups(), 1);
845   EXPECT_EQ(client->getErrors(), 1);
846   EXPECT_EQ(client->getMiss(), 1);
847   EXPECT_EQ(client->getHit(), 0);
848
849   cerr << "SSLServerCacheCloseTest test completed" << endl;
850 }
851
852 /**
853  * Verify Client Ciphers obtained using SSL MSG Callback.
854  */
855 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
856   EventBase eventBase;
857   auto clientCtx = std::make_shared<SSLContext>();
858   auto serverCtx = std::make_shared<SSLContext>();
859   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
860   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
861   serverCtx->loadPrivateKey(testKey);
862   serverCtx->loadCertificate(testCert);
863   serverCtx->loadTrustedCertificates(testCA);
864   serverCtx->loadClientCAList(testCA);
865
866   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
867   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
868   clientCtx->loadPrivateKey(testKey);
869   clientCtx->loadCertificate(testCert);
870   clientCtx->loadTrustedCertificates(testCA);
871
872   int fds[2];
873   getfds(fds);
874
875   AsyncSSLSocket::UniquePtr clientSock(
876       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
877   AsyncSSLSocket::UniquePtr serverSock(
878       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
879
880   SSLHandshakeClient client(std::move(clientSock), true, true);
881   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
882
883   eventBase.loop();
884
885   EXPECT_EQ(server.clientCiphers_,
886             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
887   EXPECT_TRUE(client.handshakeVerify_);
888   EXPECT_TRUE(client.handshakeSuccess_);
889   EXPECT_TRUE(!client.handshakeError_);
890   EXPECT_TRUE(server.handshakeVerify_);
891   EXPECT_TRUE(server.handshakeSuccess_);
892   EXPECT_TRUE(!server.handshakeError_);
893 }
894
895 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
896   EventBase eventBase;
897   auto ctx = std::make_shared<SSLContext>();
898
899   int fds[2];
900   getfds(fds);
901
902   int bufLen = 42;
903   uint8_t majorVersion = 18;
904   uint8_t minorVersion = 25;
905
906   // Create callback buf
907   auto buf = IOBuf::create(bufLen);
908   buf->append(bufLen);
909   folly::io::RWPrivateCursor cursor(buf.get());
910   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
911   cursor.write<uint16_t>(0);
912   cursor.write<uint8_t>(38);
913   cursor.write<uint8_t>(majorVersion);
914   cursor.write<uint8_t>(minorVersion);
915   cursor.skip(32);
916   cursor.write<uint32_t>(0);
917
918   SSL* ssl = ctx->createSSL();
919   SCOPE_EXIT { SSL_free(ssl); };
920   AsyncSSLSocket::UniquePtr sock(
921       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
922   sock->enableClientHelloParsing();
923
924   // Test client hello parsing in one packet
925   AsyncSSLSocket::clientHelloParsingCallback(
926       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
927   buf.reset();
928
929   auto parsedClientHello = sock->getClientHelloInfo();
930   EXPECT_TRUE(parsedClientHello != nullptr);
931   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
932   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
933 }
934
935 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
936   EventBase eventBase;
937   auto ctx = std::make_shared<SSLContext>();
938
939   int fds[2];
940   getfds(fds);
941
942   int bufLen = 42;
943   uint8_t majorVersion = 18;
944   uint8_t minorVersion = 25;
945
946   // Create callback buf
947   auto buf = IOBuf::create(bufLen);
948   buf->append(bufLen);
949   folly::io::RWPrivateCursor cursor(buf.get());
950   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
951   cursor.write<uint16_t>(0);
952   cursor.write<uint8_t>(38);
953   cursor.write<uint8_t>(majorVersion);
954   cursor.write<uint8_t>(minorVersion);
955   cursor.skip(32);
956   cursor.write<uint32_t>(0);
957
958   SSL* ssl = ctx->createSSL();
959   SCOPE_EXIT { SSL_free(ssl); };
960   AsyncSSLSocket::UniquePtr sock(
961       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
962   sock->enableClientHelloParsing();
963
964   // Test parsing with two packets with first packet size < 3
965   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
966   AsyncSSLSocket::clientHelloParsingCallback(
967       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
968       ssl, sock.get());
969   bufCopy.reset();
970   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
971   AsyncSSLSocket::clientHelloParsingCallback(
972       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
973       ssl, sock.get());
974   bufCopy.reset();
975
976   auto parsedClientHello = sock->getClientHelloInfo();
977   EXPECT_TRUE(parsedClientHello != nullptr);
978   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
979   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
980 }
981
982 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
983   EventBase eventBase;
984   auto ctx = std::make_shared<SSLContext>();
985
986   int fds[2];
987   getfds(fds);
988
989   int bufLen = 42;
990   uint8_t majorVersion = 18;
991   uint8_t minorVersion = 25;
992
993   // Create callback buf
994   auto buf = IOBuf::create(bufLen);
995   buf->append(bufLen);
996   folly::io::RWPrivateCursor cursor(buf.get());
997   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
998   cursor.write<uint16_t>(0);
999   cursor.write<uint8_t>(38);
1000   cursor.write<uint8_t>(majorVersion);
1001   cursor.write<uint8_t>(minorVersion);
1002   cursor.skip(32);
1003   cursor.write<uint32_t>(0);
1004
1005   SSL* ssl = ctx->createSSL();
1006   SCOPE_EXIT { SSL_free(ssl); };
1007   AsyncSSLSocket::UniquePtr sock(
1008       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1009   sock->enableClientHelloParsing();
1010
1011   // Test parsing with multiple small packets
1012   for (uint64_t i = 0; i < buf->length(); i += 3) {
1013     auto bufCopy = folly::IOBuf::copyBuffer(
1014         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1015     AsyncSSLSocket::clientHelloParsingCallback(
1016         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1017         ssl, sock.get());
1018     bufCopy.reset();
1019   }
1020
1021   auto parsedClientHello = sock->getClientHelloInfo();
1022   EXPECT_TRUE(parsedClientHello != nullptr);
1023   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1024   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1025 }
1026
1027 /**
1028  * Verify sucessful behavior of SSL certificate validation.
1029  */
1030 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1031   EventBase eventBase;
1032   auto clientCtx = std::make_shared<SSLContext>();
1033   auto dfServerCtx = std::make_shared<SSLContext>();
1034
1035   int fds[2];
1036   getfds(fds);
1037   getctx(clientCtx, dfServerCtx);
1038
1039   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1040   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1041
1042   AsyncSSLSocket::UniquePtr clientSock(
1043     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1044   AsyncSSLSocket::UniquePtr serverSock(
1045     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1046
1047   SSLHandshakeClient client(std::move(clientSock), true, true);
1048   clientCtx->loadTrustedCertificates(testCA);
1049
1050   SSLHandshakeServer server(std::move(serverSock), true, true);
1051
1052   eventBase.loop();
1053
1054   EXPECT_TRUE(client.handshakeVerify_);
1055   EXPECT_TRUE(client.handshakeSuccess_);
1056   EXPECT_TRUE(!client.handshakeError_);
1057   EXPECT_TRUE(!server.handshakeVerify_);
1058   EXPECT_TRUE(server.handshakeSuccess_);
1059   EXPECT_TRUE(!server.handshakeError_);
1060 }
1061
1062 /**
1063  * Verify that the client's verification callback is able to fail SSL
1064  * connection establishment.
1065  */
1066 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1067   EventBase eventBase;
1068   auto clientCtx = std::make_shared<SSLContext>();
1069   auto dfServerCtx = std::make_shared<SSLContext>();
1070
1071   int fds[2];
1072   getfds(fds);
1073   getctx(clientCtx, dfServerCtx);
1074
1075   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1076   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1077
1078   AsyncSSLSocket::UniquePtr clientSock(
1079     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1080   AsyncSSLSocket::UniquePtr serverSock(
1081     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1082
1083   SSLHandshakeClient client(std::move(clientSock), true, false);
1084   clientCtx->loadTrustedCertificates(testCA);
1085
1086   SSLHandshakeServer server(std::move(serverSock), true, true);
1087
1088   eventBase.loop();
1089
1090   EXPECT_TRUE(client.handshakeVerify_);
1091   EXPECT_TRUE(!client.handshakeSuccess_);
1092   EXPECT_TRUE(client.handshakeError_);
1093   EXPECT_TRUE(!server.handshakeVerify_);
1094   EXPECT_TRUE(!server.handshakeSuccess_);
1095   EXPECT_TRUE(server.handshakeError_);
1096 }
1097
1098 /**
1099  * Verify that the options in SSLContext can be overridden in
1100  * sslConnect/Accept.i.e specifying that no validation should be performed
1101  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1102  * the validation callback.
1103  */
1104 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1105   EventBase eventBase;
1106   auto clientCtx = std::make_shared<SSLContext>();
1107   auto dfServerCtx = std::make_shared<SSLContext>();
1108
1109   int fds[2];
1110   getfds(fds);
1111   getctx(clientCtx, dfServerCtx);
1112
1113   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1114   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1115
1116   AsyncSSLSocket::UniquePtr clientSock(
1117     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1118   AsyncSSLSocket::UniquePtr serverSock(
1119     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1120
1121   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1122   clientCtx->loadTrustedCertificates(testCA);
1123
1124   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1125
1126   eventBase.loop();
1127
1128   EXPECT_TRUE(!client.handshakeVerify_);
1129   EXPECT_TRUE(client.handshakeSuccess_);
1130   EXPECT_TRUE(!client.handshakeError_);
1131   EXPECT_TRUE(!server.handshakeVerify_);
1132   EXPECT_TRUE(server.handshakeSuccess_);
1133   EXPECT_TRUE(!server.handshakeError_);
1134 }
1135
1136 /**
1137  * Verify that the options in SSLContext can be overridden in
1138  * sslConnect/Accept. Enable verification even if context says otherwise.
1139  * Test requireClientCert with client cert
1140  */
1141 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1142   EventBase eventBase;
1143   auto clientCtx = std::make_shared<SSLContext>();
1144   auto serverCtx = std::make_shared<SSLContext>();
1145   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1146   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1147   serverCtx->loadPrivateKey(testKey);
1148   serverCtx->loadCertificate(testCert);
1149   serverCtx->loadTrustedCertificates(testCA);
1150   serverCtx->loadClientCAList(testCA);
1151
1152   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1153   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1154   clientCtx->loadPrivateKey(testKey);
1155   clientCtx->loadCertificate(testCert);
1156   clientCtx->loadTrustedCertificates(testCA);
1157
1158   int fds[2];
1159   getfds(fds);
1160
1161   AsyncSSLSocket::UniquePtr clientSock(
1162       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1163   AsyncSSLSocket::UniquePtr serverSock(
1164       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1165
1166   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1167   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1168
1169   eventBase.loop();
1170
1171   EXPECT_TRUE(client.handshakeVerify_);
1172   EXPECT_TRUE(client.handshakeSuccess_);
1173   EXPECT_FALSE(client.handshakeError_);
1174   EXPECT_TRUE(server.handshakeVerify_);
1175   EXPECT_TRUE(server.handshakeSuccess_);
1176   EXPECT_FALSE(server.handshakeError_);
1177 }
1178
1179 /**
1180  * Verify that the client's verification callback is able to override
1181  * the preverification failure and allow a successful connection.
1182  */
1183 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1184   EventBase eventBase;
1185   auto clientCtx = std::make_shared<SSLContext>();
1186   auto dfServerCtx = std::make_shared<SSLContext>();
1187
1188   int fds[2];
1189   getfds(fds);
1190   getctx(clientCtx, dfServerCtx);
1191
1192   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1193   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1194
1195   AsyncSSLSocket::UniquePtr clientSock(
1196     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1197   AsyncSSLSocket::UniquePtr serverSock(
1198     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1199
1200   SSLHandshakeClient client(std::move(clientSock), false, true);
1201   SSLHandshakeServer server(std::move(serverSock), true, true);
1202
1203   eventBase.loop();
1204
1205   EXPECT_TRUE(client.handshakeVerify_);
1206   EXPECT_TRUE(client.handshakeSuccess_);
1207   EXPECT_TRUE(!client.handshakeError_);
1208   EXPECT_TRUE(!server.handshakeVerify_);
1209   EXPECT_TRUE(server.handshakeSuccess_);
1210   EXPECT_TRUE(!server.handshakeError_);
1211 }
1212
1213 /**
1214  * Verify that specifying that no validation should be performed allows an
1215  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1216  * callback.
1217  */
1218 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1219   EventBase eventBase;
1220   auto clientCtx = std::make_shared<SSLContext>();
1221   auto dfServerCtx = std::make_shared<SSLContext>();
1222
1223   int fds[2];
1224   getfds(fds);
1225   getctx(clientCtx, dfServerCtx);
1226
1227   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1228   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1229
1230   AsyncSSLSocket::UniquePtr clientSock(
1231     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1232   AsyncSSLSocket::UniquePtr serverSock(
1233     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1234
1235   SSLHandshakeClient client(std::move(clientSock), false, false);
1236   SSLHandshakeServer server(std::move(serverSock), false, false);
1237
1238   eventBase.loop();
1239
1240   EXPECT_TRUE(!client.handshakeVerify_);
1241   EXPECT_TRUE(client.handshakeSuccess_);
1242   EXPECT_TRUE(!client.handshakeError_);
1243   EXPECT_TRUE(!server.handshakeVerify_);
1244   EXPECT_TRUE(server.handshakeSuccess_);
1245   EXPECT_TRUE(!server.handshakeError_);
1246 }
1247
1248 /**
1249  * Test requireClientCert with client cert
1250  */
1251 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1252   EventBase eventBase;
1253   auto clientCtx = std::make_shared<SSLContext>();
1254   auto serverCtx = std::make_shared<SSLContext>();
1255   serverCtx->setVerificationOption(
1256       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1257   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1258   serverCtx->loadPrivateKey(testKey);
1259   serverCtx->loadCertificate(testCert);
1260   serverCtx->loadTrustedCertificates(testCA);
1261   serverCtx->loadClientCAList(testCA);
1262
1263   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1264   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1265   clientCtx->loadPrivateKey(testKey);
1266   clientCtx->loadCertificate(testCert);
1267   clientCtx->loadTrustedCertificates(testCA);
1268
1269   int fds[2];
1270   getfds(fds);
1271
1272   AsyncSSLSocket::UniquePtr clientSock(
1273       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1274   AsyncSSLSocket::UniquePtr serverSock(
1275       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1276
1277   SSLHandshakeClient client(std::move(clientSock), true, true);
1278   SSLHandshakeServer server(std::move(serverSock), true, true);
1279
1280   eventBase.loop();
1281
1282   EXPECT_TRUE(client.handshakeVerify_);
1283   EXPECT_TRUE(client.handshakeSuccess_);
1284   EXPECT_FALSE(client.handshakeError_);
1285   EXPECT_TRUE(server.handshakeVerify_);
1286   EXPECT_TRUE(server.handshakeSuccess_);
1287   EXPECT_FALSE(server.handshakeError_);
1288 }
1289
1290
1291 /**
1292  * Test requireClientCert with no client cert
1293  */
1294 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1295   EventBase eventBase;
1296   auto clientCtx = std::make_shared<SSLContext>();
1297   auto serverCtx = std::make_shared<SSLContext>();
1298   serverCtx->setVerificationOption(
1299       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1300   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1301   serverCtx->loadPrivateKey(testKey);
1302   serverCtx->loadCertificate(testCert);
1303   serverCtx->loadTrustedCertificates(testCA);
1304   serverCtx->loadClientCAList(testCA);
1305   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1306   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1307
1308   int fds[2];
1309   getfds(fds);
1310
1311   AsyncSSLSocket::UniquePtr clientSock(
1312       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1313   AsyncSSLSocket::UniquePtr serverSock(
1314       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1315
1316   SSLHandshakeClient client(std::move(clientSock), false, false);
1317   SSLHandshakeServer server(std::move(serverSock), false, false);
1318
1319   eventBase.loop();
1320
1321   EXPECT_FALSE(server.handshakeVerify_);
1322   EXPECT_FALSE(server.handshakeSuccess_);
1323   EXPECT_TRUE(server.handshakeError_);
1324 }
1325
1326 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1327   EventBase eb;
1328
1329   // Set up SSL context.
1330   auto sslContext = std::make_shared<SSLContext>();
1331   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1332
1333   // create SSL socket
1334   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1335
1336   EXPECT_EQ(1500, socket->getMinWriteSize());
1337
1338   socket->setMinWriteSize(0);
1339   EXPECT_EQ(0, socket->getMinWriteSize());
1340   socket->setMinWriteSize(50000);
1341   EXPECT_EQ(50000, socket->getMinWriteSize());
1342 }
1343
1344 class ReadCallbackTerminator : public ReadCallback {
1345  public:
1346   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1347       : ReadCallback(wcb)
1348       , base_(base) {}
1349
1350   // Do not write data back, terminate the loop.
1351   void readDataAvailable(size_t len) noexcept override {
1352     std::cerr << "readDataAvailable, len " << len << std::endl;
1353
1354     currentBuffer.length = len;
1355
1356     buffers.push_back(currentBuffer);
1357     currentBuffer.reset();
1358     state = STATE_SUCCEEDED;
1359
1360     socket_->setReadCB(nullptr);
1361     base_->terminateLoopSoon();
1362   }
1363  private:
1364   EventBase* base_;
1365 };
1366
1367
1368 /**
1369  * Test a full unencrypted codepath
1370  */
1371 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1372   EventBase base;
1373
1374   auto clientCtx = std::make_shared<folly::SSLContext>();
1375   auto serverCtx = std::make_shared<folly::SSLContext>();
1376   int fds[2];
1377   getfds(fds);
1378   getctx(clientCtx, serverCtx);
1379   auto client = AsyncSSLSocket::newSocket(
1380                   clientCtx, &base, fds[0], false, true);
1381   auto server = AsyncSSLSocket::newSocket(
1382                   serverCtx, &base, fds[1], true, true);
1383
1384   ReadCallbackTerminator readCallback(&base, nullptr);
1385   server->setReadCB(&readCallback);
1386   readCallback.setSocket(server);
1387
1388   uint8_t buf[128];
1389   memset(buf, 'a', sizeof(buf));
1390   client->write(nullptr, buf, sizeof(buf));
1391
1392   // Check that bytes are unencrypted
1393   char c;
1394   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1395   EXPECT_EQ('a', c);
1396
1397   EventBaseAborter eba(&base, 3000);
1398   base.loop();
1399
1400   EXPECT_EQ(1, readCallback.buffers.size());
1401   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1402
1403   server->setReadCB(&readCallback);
1404
1405   // Unencrypted
1406   server->sslAccept(nullptr);
1407   client->sslConn(nullptr);
1408
1409   // Do NOT wait for handshake, writing should be queued and happen after
1410
1411   client->write(nullptr, buf, sizeof(buf));
1412
1413   // Check that bytes are *not* unencrypted
1414   char c2;
1415   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1416   EXPECT_NE('a', c2);
1417
1418
1419   base.loop();
1420
1421   EXPECT_EQ(2, readCallback.buffers.size());
1422   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1423 }
1424
1425 } // namespace
1426
1427 ///////////////////////////////////////////////////////////////////////////
1428 // init_unit_test_suite
1429 ///////////////////////////////////////////////////////////////////////////
1430 namespace {
1431 struct Initializer {
1432   Initializer() {
1433     signal(SIGPIPE, SIG_IGN);
1434   }
1435 };
1436 Initializer initializer;
1437 } // anonymous