2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <folly/Foreach.h>
17 #include <folly/io/Cursor.h>
18 #include <folly/io/async/AsyncSSLSocket.h>
19 #include <folly/io/async/AsyncSocket.h>
20 #include <folly/io/async/EventBase.h>
22 #include <gtest/gtest.h>
23 #include <gmock/gmock.h>
28 using namespace testing;
// Test double for AsyncSSLSocket.
// - The constructor forces both the socket state and the SSL state to
//   "established".
// - gMock replaces sslWriteImpl() and getRawBytesWritten().
// Together these let tests inspect exactly what performWrite() passes to the
// SSL write path, and script the raw-bytes counter used for EOR tracking.
32 class MockAsyncSSLSocket : public AsyncSSLSocket{
// Factory used by the fixture. It wraps the mock in a shared_ptr and attaches
// a fresh SSL object created from ctx's SSL_CTX. The SSL object's fd is set
// to -1, which is not a valid descriptor. Presumably this ensures no real
// network I/O happens and writes are observed only through the mocked
// sslWriteImpl().
34 static std::shared_ptr<MockAsyncSSLSocket> newSocket(
35 const std::shared_ptr<SSLContext>& ctx,
37 auto sock = std::shared_ptr<MockAsyncSSLSocket>(
38 new MockAsyncSSLSocket(ctx, evb),
40 sock->ssl_ = SSL_new(ctx->getSSLCtx());
41 SSL_set_fd(sock->ssl_, -1);
45 // Fake constructor sets the state to established without call to connect
// AsyncSocket is named in the mem-initializer list alongside AsyncSSLSocket.
// That is only valid if AsyncSocket is a virtual base of AsyncSSLSocket,
// because the most-derived class must initialize its virtual bases.
47 MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
49 : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
50 state_ = AsyncSocket::StateEnum::ESTABLISHED;
51 sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
54 // mock the calls to SSL_write to see the buffer length and contents
55 MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
57 // mock the calls to getRawBytesWritten()
// (the EOR tests script its successive return values with WillOnce(Return()))
58 MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
60 // public wrapper for protected interface
// Forwards all arguments to performWrite() unchanged.
// countWritten and partialWritten are out-params that the tests read after
// the call:
// - countWritten: how many whole iovec entries were written.
// - partialWritten: how many bytes of entry [countWritten] were written.
// consumeVec() in the fixture uses exactly this pair to trim the vec before
// resubmitting it.
61 WriteResult testPerformWrite(
65 uint32_t* countWritten,
66 uint32_t* partialWritten) {
67 return performWrite(vec, count, flags, countWritten, partialWritten);
// Checks the socket's EOR bookkeeping with non-fatal EXPECT_EQs.
// - appEor is compared against appEorByteNo_, an application-byte position.
// - rawEor is compared against minEorRawByteNo_, a raw (post-SSL) byte
//   position.
// The tests pass (0, 0) to check that no EOR position is currently recorded.
70 void checkEor(size_t appEor, size_t rawEor) {
71 EXPECT_EQ(appEor, appEorByteNo_);
72 EXPECT_EQ(rawEor, minEorRawByteNo_);
// Called by the EOR tests before writing, so that EOR offsets start from a
// non-zero application-byte count. Presumably it assigns the
// application-bytes-written counter, but the body is not visible in this
// view (TODO confirm).
75 void setAppBytesWritten(size_t n) {
// Fixture for the performWrite() tests. Each test gets its own EventBase,
// SSLContext and MockAsyncSSLSocket. It also gets source_: a
// 26 * 500 = 13000-byte reference buffer holding the lowercase alphabet
// repeated 500 times. Written data is compared against source_.
80 class AsyncSSLSocketWriteTest : public testing::Test {
// Members are initialized in declaration order: eventBase_, then
// sslContext_, then sock_. Both sslContext_ and eventBase_ therefore already
// exist when they are handed to newSocket() below.
82 AsyncSSLSocketWriteTest() :
83 sslContext_(new SSLContext()),
84 sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
85 for (int i = 0; i < 500; i++) {
86 memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
90 // Make an iovec containing chunks of the reference text with requested sizes
// Returns an owning array of sizes.size() iovec entries. Each iov_base points
// into source_, which the array does not own. Each iov_len is the
// corresponding requested size.
92 std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
93 std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
96 for (auto size: sizes) {
97 vec[i].iov_base = (void *)(source_ + pos);
98 vec[i++].iov_len = size;
104 // Verify that the given buf/pos matches the reference text
// Compares n bytes at buf with source_ starting at offset pos.
// Note: ASSERT_EQ aborts only this helper function, not the calling
// mock-action lambda or the test body.
105 void verifyVec(const void *buf, int n, int pos) {
106 ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
109 // Update a vec on partial write
// Advances vec[countWritten] past its first partialWritten bytes, shrinking
// iov_len to match. The caller can then resubmit starting at
// vec + countWritten.
// NOTE(review): there is no bounds check; the caller must ensure
// countWritten indexes a valid entry and partialWritten <= its iov_len.
110 void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
111 vec[countWritten].iov_base =
112 ((char *)vec[countWritten].iov_base) + partialWritten;
113 vec[countWritten].iov_len -= partialWritten;
116 EventBase eventBase_;
117 std::shared_ptr<SSLContext> sslContext_;
118 std::shared_ptr<MockAsyncSSLSocket> sock_;
// Reference text: "abc...xyz" repeated 500 times.
119 char source_[26 * 500];
123 // The entire vec fits in one packet
// Three 3-byte iovec entries must be coalesced into a single 9-byte
// sslWriteImpl() call whose contents match the start of source_.
// performWrite() must then report all n entries written, with no partial
// remainder.
// The lambda's `int n` parameter (bytes passed to the write) shadows the
// test's `n` (the iovec entry count passed to testPerformWrite).
124 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
126 auto vec = makeVec({3, 3, 3});
127 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
128 .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
129 verifyVec(buf, n, 0);
131 uint32_t countWritten = 0;
132 uint32_t partialWritten = 0;
133 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
135 EXPECT_EQ(countWritten, n);
136 EXPECT_EQ(partialWritten, 0);
139 // First packet is full, second two go in one packet
// Chunks {1500, 3, 3}. The test expects exactly one 1500-byte sslWriteImpl()
// call followed by one 6-byte call. pos is captured by reference and serves
// as the source_ offset each mocked write is verified against.
// All n entries must be reported written, with no partial remainder.
140 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
142 auto vec = makeVec({1500, 3, 3});
144 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
145 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
146 verifyVec(buf, n, pos);
149 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
150 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
151 verifyVec(buf, n, pos);
154 uint32_t countWritten = 0;
155 uint32_t partialWritten = 0;
156 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
158 EXPECT_EQ(countWritten, n);
159 EXPECT_EQ(partialWritten, 0);
162 // Two exactly full packets (coalesce ends midway through second chunk)
// Chunks {1000, 1000, 1000} total 3000 bytes. Every sslWriteImpl() call is
// expected to be exactly 1500 bytes:
// - the first write takes chunk 1 plus half of chunk 2;
// - the second write takes the rest.
// All n entries must be reported written.
163 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
165 auto vec = makeVec({1000, 1000, 1000});
167 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
169 .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
170 verifyVec(buf, n, pos);
173 uint32_t countWritten = 0;
174 uint32_t partialWritten = 0;
175 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
177 EXPECT_EQ(countWritten, n);
178 EXPECT_EQ(partialWritten, 0);
181 // Partial write success midway through a coalesced vec
// Five 300-byte chunks.
// 1. The first 1500-byte sslWriteImpl() call reports only 1000 bytes
//    accepted. performWrite() must report 3 whole entries plus 100 bytes of
//    the 4th entry.
// 2. consumeVec() trims that 4th entry, and the remaining n - 3 entries are
//    resubmitted.
// 3. This must produce a single 500-byte write that completes both of them.
182 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
184 auto vec = makeVec({300, 300, 300, 300, 300});
186 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
187 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
188 verifyVec(buf, n, pos);
190 return 1000; /* 500 bytes "pending" */ }));
191 uint32_t countWritten = 0;
192 uint32_t partialWritten = 0;
193 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
195 EXPECT_EQ(countWritten, 3);
196 EXPECT_EQ(partialWritten, 100);
197 consumeVec(vec.get(), countWritten, partialWritten);
198 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
199 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
200 verifyVec(buf, n, pos);
// Resubmit only the unwritten tail of the vec.
203 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
205 &countWritten, &partialWritten);
206 EXPECT_EQ(countWritten, 2);
207 EXPECT_EQ(partialWritten, 0);
210 // coalesce ends exactly on a buffer boundary
// Chunks {1000, 500, 500}:
// - the first 1500-byte write covers exactly the first two chunks;
// - a second 500-byte write carries the last chunk.
// All 3 entries must be reported written.
211 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
213 auto vec = makeVec({1000, 500, 500});
215 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
216 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
217 verifyVec(buf, n, pos);
220 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
221 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
222 verifyVec(buf, n, pos);
225 uint32_t countWritten = 0;
226 uint32_t partialWritten = 0;
227 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
229 EXPECT_EQ(countWritten, 3);
230 EXPECT_EQ(partialWritten, 0);
233 // partial write midway through first chunk
// Chunks {1000, 500}. One 1500-byte sslWriteImpl() call is expected.
// The assertions then expect 0 whole entries and 700 partial bytes. The
// mock's return statement is not visible in this view; 700 is implied by
// those assertions.
// The remaining 300 + 500 = 800 bytes are resubmitted and must go out as
// one 800-byte write.
234 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
236 auto vec = makeVec({1000, 500});
238 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
239 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
240 verifyVec(buf, n, pos);
243 uint32_t countWritten = 0;
244 uint32_t partialWritten = 0;
245 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
247 EXPECT_EQ(countWritten, 0);
248 EXPECT_EQ(partialWritten, 700);
249 consumeVec(vec.get(), countWritten, partialWritten);
250 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
251 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
252 verifyVec(buf, n, pos);
255 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
257 &countWritten, &partialWritten);
258 EXPECT_EQ(countWritten, 2);
259 EXPECT_EQ(partialWritten, 0);
262 // Repeat coalescing2 with WriteFlags::EOR
// EOR tracking starts disabled and is enabled here. With 500 application
// bytes already counted, the EOR byte is at application position 500 + 1506.
// - During the first (1500-byte) write, no EOR position may be recorded.
// - During the final 6-byte write, the app EOR must be set. The raw EOR must
//   equal the first scripted getRawBytesWritten() value (3600) plus n.
// - After the whole vec is written, both positions must be back to 0.
// NOTE(review): the [=, &pos] lambdas capture `this` implicitly through [=];
// that implicit capture is deprecated in C++20.
263 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
265 auto vec = makeVec({1500, 3, 3});
267 const size_t initAppBytesWritten = 500;
268 const size_t appEor = initAppBytesWritten + 1506;
270 sock_->setAppBytesWritten(initAppBytesWritten);
271 EXPECT_FALSE(sock_->isEorTrackingEnabled());
272 sock_->setEorTracking(true);
273 EXPECT_TRUE(sock_->isEorTrackingEnabled());
275 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
276 // rawBytesWritten after writing initAppBytesWritten + 1500
277 // + some random SSL overhead
278 .WillOnce(Return(3600))
279 // rawBytesWritten after writing last 6 bytes
280 // + some random SSL overhead
281 .WillOnce(Return(3728));
282 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
283 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
284 // the first 1500 does not have the EOR byte
285 sock_->checkEor(0, 0);
286 verifyVec(buf, n, pos);
289 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
290 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
291 sock_->checkEor(appEor, 3600 + n);
292 verifyVec(buf, n, pos);
296 uint32_t countWritten = 0;
297 uint32_t partialWritten = 0;
298 sock_->testPerformWrite(vec.get(), n , WriteFlags::EOR,
299 &countWritten, &partialWritten);
300 EXPECT_EQ(countWritten, n);
301 EXPECT_EQ(partialWritten, 0);
// A fully completed write must leave no EOR position recorded.
302 sock_->checkEor(0, 0);
305 // coalescing with leftover at the last chunk
306 // WriteFlags::EOR turned on
// Chunks {600, 600, 600} total 1800 bytes. The expected writes are:
// - a 1500-byte write (first two chunks plus 300 bytes of the third);
// - a 300-byte write for the leftover.
// With 500 application bytes already counted, the EOR is at application
// position 500 + 1800.
// - No EOR position may be recorded during the first write.
// - During the 300-byte write, the raw EOR must equal 3600 + n.
// - Both positions must be cleared once the vec is fully written.
307 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
309 auto vec = makeVec({600, 600, 600});
311 const size_t initAppBytesWritten = 500;
312 const size_t appEor = initAppBytesWritten + 1800;
314 sock_->setAppBytesWritten(initAppBytesWritten);
315 sock_->setEorTracking(true);
317 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
318 // rawBytesWritten after writing initAppBytesWritten + 1500 bytes
319 // + some random SSL overhead
320 .WillOnce(Return(3600))
321 // rawBytesWritten after writing last 300 bytes
322 // + some random SSL overhead
323 .WillOnce(Return(4100));
324 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
325 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
326 // the first 1500 does not have the EOR byte
327 sock_->checkEor(0, 0);
328 verifyVec(buf, n, pos);
331 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
332 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
333 sock_->checkEor(appEor, 3600 + n);
334 verifyVec(buf, n, pos);
338 uint32_t countWritten = 0;
339 uint32_t partialWritten = 0;
340 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
341 &countWritten, &partialWritten);
342 EXPECT_EQ(countWritten, n);
343 EXPECT_EQ(partialWritten, 0);
344 sock_->checkEor(0, 0);
347 // WriteFlags::EOR set
349 // Partial write after the 1000th byte
// A single 1600-byte entry with 500 application bytes already counted.
// 1. During the first 1600-byte write, the app EOR (500 + 1600) must be set
//    and the raw EOR must be 2000 + n.
// 2. The assertions expect that write to be partial (1000 bytes). The mock's
//    return statement is not visible in this view.
// 3. After the partial write, the EOR positions must remain set.
// 4. The remaining 600 bytes are resubmitted against a new raw base of 3100.
// 5. Both positions must be cleared once the write completes.
// initAppBytesWritten and appEor are static constexpr, so the
// [this, &pos] lambdas can read appEor without capturing it.
350 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
352 auto vec = makeVec({1600});
354 static constexpr size_t initAppBytesWritten = 500;
355 static constexpr size_t appEor = initAppBytesWritten + 1600;
357 sock_->setAppBytesWritten(initAppBytesWritten);
358 sock_->setEorTracking(true);
360 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
361 // rawBytesWritten after the initAppBytesWritten
362 // + some random SSL overhead
363 .WillOnce(Return(2000))
364 // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
365 // + some random SSL overhead
366 .WillOnce(Return(3100));
367 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
368 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
369 sock_->checkEor(appEor, 2000 + n);
370 verifyVec(buf, n, pos);
374 uint32_t countWritten = 0;
375 uint32_t partialWritten = 0;
376 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
377 &countWritten, &partialWritten);
378 EXPECT_EQ(countWritten, 0);
379 EXPECT_EQ(partialWritten, 1000);
// The EOR byte has not been written yet, so the EOR positions must persist.
380 sock_->checkEor(appEor, 2000 + 1600);
381 consumeVec(vec.get(), countWritten, partialWritten);
// Fresh raw-byte script for the resubmission of the remaining 600 bytes.
383 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
384 .WillOnce(Return(3100))
385 .WillOnce(Return(3800));
386 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
387 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
388 sock_->checkEor(appEor, 3100 + n);
389 verifyVec(buf, n, pos);
392 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
394 &countWritten, &partialWritten);
395 EXPECT_EQ(countWritten, n);
396 EXPECT_EQ(partialWritten, 0);
397 sock_->checkEor(0, 0);