d874509b8cd65d49ed21bea244eb5f2a2c0f87d1
[folly.git] / folly / io / async / test / SSLSessionTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/test/AsyncSSLSocketTest.h>
18 #include <folly/portability/GTest.h>
19 #include <folly/portability/Sockets.h>
20 #include <folly/ssl/SSLSession.h>
21
22 #include <memory>
23
24 using namespace std;
25 using namespace testing;
26 using folly::ssl::SSLSession;
27
28 namespace folly {
29
30 void getfds(int fds[2]) {
31   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
32     LOG(ERROR) << "failed to create socketpair: " << strerror(errno);
33   }
34   for (int idx = 0; idx < 2; ++idx) {
35     int flags = fcntl(fds[idx], F_GETFL, 0);
36     if (flags == -1) {
37       LOG(ERROR) << "failed to get flags for socket " << idx << ": "
38                  << strerror(errno);
39     }
40     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
41       LOG(ERROR) << "failed to put socket " << idx
42                  << " in non-blocking mode: " << strerror(errno);
43     }
44   }
45 }
46
47 void getctx(
48     std::shared_ptr<folly::SSLContext> clientCtx,
49     std::shared_ptr<folly::SSLContext> serverCtx) {
50   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
51
52   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
53   serverCtx->loadCertificate(kTestCert);
54   serverCtx->loadPrivateKey(kTestKey);
55 }
56
57 class SSLSessionTest : public testing::Test {
58  public:
59   void SetUp() override {
60     clientCtx.reset(new folly::SSLContext());
61     dfServerCtx.reset(new folly::SSLContext());
62     hskServerCtx.reset(new folly::SSLContext());
63     serverName = "xyz.newdev.facebook.com";
64     getctx(clientCtx, dfServerCtx);
65   }
66
67   void TearDown() override {}
68
69   folly::EventBase eventBase;
70   std::shared_ptr<SSLContext> clientCtx;
71   std::shared_ptr<SSLContext> dfServerCtx;
72   // Use the same SSLContext to continue the handshake after
73   // tlsext_hostname match.
74   std::shared_ptr<SSLContext> hskServerCtx;
75   std::string serverName;
76 };
77
78 /**
79  * 1. Client sends TLSEXT_HOSTNAME in client hello.
80  * 2. Server found a match SSL_CTX and use this SSL_CTX to
81  *    continue the SSL handshake.
82  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
83  */
84 TEST_F(SSLSessionTest, BasicTest) {
85   std::unique_ptr<SSLSession> sess;
86
87   {
88     int fds[2];
89     getfds(fds);
90     AsyncSSLSocket::UniquePtr clientSock(
91         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
92     auto clientPtr = clientSock.get();
93     AsyncSSLSocket::UniquePtr serverSock(
94         new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
95     SSLHandshakeClient client(std::move(clientSock), false, false);
96     SSLHandshakeServerParseClientHello server(
97         std::move(serverSock), false, false);
98
99     eventBase.loop();
100     ASSERT_TRUE(client.handshakeSuccess_);
101
102     sess = std::make_unique<SSLSession>(clientPtr->getSSLSession());
103     ASSERT_NE(sess.get(), nullptr);
104   }
105
106   {
107     int fds[2];
108     getfds(fds);
109     AsyncSSLSocket::UniquePtr clientSock(
110         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
111     auto clientPtr = clientSock.get();
112     clientSock->setSSLSession(sess->getRawSSLSessionDangerous(), true);
113     AsyncSSLSocket::UniquePtr serverSock(
114         new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
115     SSLHandshakeClient client(std::move(clientSock), false, false);
116     SSLHandshakeServerParseClientHello server(
117         std::move(serverSock), false, false);
118
119     eventBase.loop();
120     ASSERT_TRUE(client.handshakeSuccess_);
121     ASSERT_TRUE(clientPtr->getSSLSessionReused());
122   }
123 }
124 TEST_F(SSLSessionTest, SerializeDeserializeTest) {
125   std::string sessiondata;
126
127   {
128     int fds[2];
129     getfds(fds);
130     AsyncSSLSocket::UniquePtr clientSock(
131         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
132     auto clientPtr = clientSock.get();
133     AsyncSSLSocket::UniquePtr serverSock(
134         new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
135     SSLHandshakeClient client(std::move(clientSock), false, false);
136     SSLHandshakeServerParseClientHello server(
137         std::move(serverSock), false, false);
138
139     eventBase.loop();
140     ASSERT_TRUE(client.handshakeSuccess_);
141
142     std::unique_ptr<SSLSession> sess =
143         std::make_unique<SSLSession>(clientPtr->getSSLSession());
144     sessiondata = sess->serialize();
145     ASSERT_TRUE(!sessiondata.empty());
146   }
147
148   {
149     int fds[2];
150     getfds(fds);
151     AsyncSSLSocket::UniquePtr clientSock(
152         new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
153     auto clientPtr = clientSock.get();
154     std::unique_ptr<SSLSession> sess =
155         std::make_unique<SSLSession>(sessiondata);
156     ASSERT_NE(sess.get(), nullptr);
157     clientSock->setSSLSession(sess->getRawSSLSessionDangerous(), true);
158     AsyncSSLSocket::UniquePtr serverSock(
159         new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
160     SSLHandshakeClient client(std::move(clientSock), false, false);
161     SSLHandshakeServerParseClientHello server(
162         std::move(serverSock), false, false);
163
164     eventBase.loop();
165     ASSERT_TRUE(client.handshakeSuccess_);
166     ASSERT_TRUE(clientPtr->getSSLSessionReused());
167   }
168 }
169
170 TEST_F(SSLSessionTest, GetSessionID) {
171   int fds[2];
172   getfds(fds);
173   AsyncSSLSocket::UniquePtr clientSock(
174       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
175   auto clientPtr = clientSock.get();
176   AsyncSSLSocket::UniquePtr serverSock(
177       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
178   SSLHandshakeClient client(std::move(clientSock), false, false);
179   SSLHandshakeServerParseClientHello server(
180       std::move(serverSock), false, false);
181
182   eventBase.loop();
183   ASSERT_TRUE(client.handshakeSuccess_);
184
185   std::unique_ptr<SSLSession> sess =
186       std::make_unique<SSLSession>(clientPtr->getSSLSession());
187   ASSERT_NE(sess, nullptr);
188   auto sessID = sess->getSessionID();
189   ASSERT_GE(sessID.length(), 0);
190 }
191 }