/*
 * Copyright 2017 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

using namespace testing;
31 class MockAsyncSSLSocket : public AsyncSSLSocket{
33 static std::shared_ptr<MockAsyncSSLSocket> newSocket(
34 const std::shared_ptr<SSLContext>& ctx,
36 auto sock = std::shared_ptr<MockAsyncSSLSocket>(
37 new MockAsyncSSLSocket(ctx, evb),
39 sock->ssl_ = SSL_new(ctx->getSSLCtx());
40 SSL_set_fd(sock->ssl_, -1);
44 // Fake constructor sets the state to established without call to connect
46 MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
48 : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
49 state_ = AsyncSocket::StateEnum::ESTABLISHED;
50 sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
53 // mock the calls to SSL_write to see the buffer length and contents
54 MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
56 // mock the calls to getRawBytesWritten()
57 MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
59 // public wrapper for protected interface
60 WriteResult testPerformWrite(
64 uint32_t* countWritten,
65 uint32_t* partialWritten) {
66 return performWrite(vec, count, flags, countWritten, partialWritten);
69 void checkEor(size_t appEor, size_t rawEor) {
70 EXPECT_EQ(appEor, appEorByteNo_);
71 EXPECT_EQ(rawEor, minEorRawByteNo_);
74 void setAppBytesWritten(size_t n) {
79 class AsyncSSLSocketWriteTest : public testing::Test {
81 AsyncSSLSocketWriteTest() :
82 sslContext_(new SSLContext()),
83 sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
84 for (int i = 0; i < 500; i++) {
85 memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
89 // Make an iovec containing chunks of the reference text with requested sizes
91 std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
92 std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
95 for (auto size: sizes) {
96 vec[i].iov_base = (void *)(source_ + pos);
97 vec[i++].iov_len = size;
103 // Verify that the given buf/pos matches the reference text
104 void verifyVec(const void *buf, int n, int pos) {
105 ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
108 // Update a vec on partial write
109 void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
110 vec[countWritten].iov_base =
111 ((char *)vec[countWritten].iov_base) + partialWritten;
112 vec[countWritten].iov_len -= partialWritten;
115 EventBase eventBase_;
116 std::shared_ptr<SSLContext> sslContext_;
117 std::shared_ptr<MockAsyncSSLSocket> sock_;
118 char source_[26 * 500];
122 // The entire vec fits in one packet
123 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
125 auto vec = makeVec({3, 3, 3});
126 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
127 .WillOnce(Invoke([this] (SSL *, const void *buf, int m) {
128 verifyVec(buf, m, 0);
130 uint32_t countWritten = 0;
131 uint32_t partialWritten = 0;
132 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
134 EXPECT_EQ(countWritten, n);
135 EXPECT_EQ(partialWritten, 0);
138 // First packet is full, second two go in one packet
139 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
141 auto vec = makeVec({1500, 3, 3});
143 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
144 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
145 verifyVec(buf, m, pos);
148 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
149 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
150 verifyVec(buf, m, pos);
153 uint32_t countWritten = 0;
154 uint32_t partialWritten = 0;
155 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
157 EXPECT_EQ(countWritten, n);
158 EXPECT_EQ(partialWritten, 0);
161 // Two exactly full packets (coalesce ends midway through second chunk)
162 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
164 auto vec = makeVec({1000, 1000, 1000});
166 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
168 .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int m) {
169 verifyVec(buf, m, pos);
172 uint32_t countWritten = 0;
173 uint32_t partialWritten = 0;
174 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
176 EXPECT_EQ(countWritten, n);
177 EXPECT_EQ(partialWritten, 0);
180 // Partial write success midway through a coalesced vec
181 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
183 auto vec = makeVec({300, 300, 300, 300, 300});
185 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
186 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
187 verifyVec(buf, m, pos);
189 return 1000; /* 500 bytes "pending" */ }));
190 uint32_t countWritten = 0;
191 uint32_t partialWritten = 0;
192 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
194 EXPECT_EQ(countWritten, 3);
195 EXPECT_EQ(partialWritten, 100);
196 consumeVec(vec.get(), countWritten, partialWritten);
197 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
198 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
199 verifyVec(buf, m, pos);
202 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
204 &countWritten, &partialWritten);
205 EXPECT_EQ(countWritten, 2);
206 EXPECT_EQ(partialWritten, 0);
209 // coalesce ends exactly on a buffer boundary
210 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
212 auto vec = makeVec({1000, 500, 500});
214 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
215 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
216 verifyVec(buf, m, pos);
219 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
220 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
221 verifyVec(buf, m, pos);
224 uint32_t countWritten = 0;
225 uint32_t partialWritten = 0;
226 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
228 EXPECT_EQ(countWritten, 3);
229 EXPECT_EQ(partialWritten, 0);
232 // partial write midway through first chunk
233 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
235 auto vec = makeVec({1000, 500});
237 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
238 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
239 verifyVec(buf, m, pos);
242 uint32_t countWritten = 0;
243 uint32_t partialWritten = 0;
244 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
246 EXPECT_EQ(countWritten, 0);
247 EXPECT_EQ(partialWritten, 700);
248 consumeVec(vec.get(), countWritten, partialWritten);
249 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
250 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
251 verifyVec(buf, m, pos);
254 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
256 &countWritten, &partialWritten);
257 EXPECT_EQ(countWritten, 2);
258 EXPECT_EQ(partialWritten, 0);
261 // Repeat coalescing2 with WriteFlags::EOR
262 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
264 auto vec = makeVec({1500, 3, 3});
266 const size_t initAppBytesWritten = 500;
267 const size_t appEor = initAppBytesWritten + 1506;
269 sock_->setAppBytesWritten(initAppBytesWritten);
270 EXPECT_FALSE(sock_->isEorTrackingEnabled());
271 sock_->setEorTracking(true);
272 EXPECT_TRUE(sock_->isEorTrackingEnabled());
274 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
275 // rawBytesWritten after writting initAppBytesWritten + 1500
276 // + some random SSL overhead
277 .WillOnce(Return(3600u))
278 // rawBytesWritten after writting last 6 bytes
279 // + some random SSL overhead
280 .WillOnce(Return(3728u));
281 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
282 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
283 // the first 1500 does not have the EOR byte
284 sock_->checkEor(0, 0);
285 verifyVec(buf, m, pos);
288 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
289 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
290 sock_->checkEor(appEor, 3600 + m);
291 verifyVec(buf, m, pos);
295 uint32_t countWritten = 0;
296 uint32_t partialWritten = 0;
297 sock_->testPerformWrite(vec.get(), n , WriteFlags::EOR,
298 &countWritten, &partialWritten);
299 EXPECT_EQ(countWritten, n);
300 EXPECT_EQ(partialWritten, 0);
301 sock_->checkEor(0, 0);
304 // coalescing with left over at the last chunk
305 // WriteFlags::EOR turned on
306 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
308 auto vec = makeVec({600, 600, 600});
310 const size_t initAppBytesWritten = 500;
311 const size_t appEor = initAppBytesWritten + 1800;
313 sock_->setAppBytesWritten(initAppBytesWritten);
314 sock_->setEorTracking(true);
316 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
317 // rawBytesWritten after writting initAppBytesWritten + 1500 bytes
318 // + some random SSL overhead
319 .WillOnce(Return(3600))
320 // rawBytesWritten after writting last 300 bytes
321 // + some random SSL overhead
322 .WillOnce(Return(4100));
323 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
324 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
325 // the first 1500 does not have the EOR byte
326 sock_->checkEor(0, 0);
327 verifyVec(buf, m, pos);
330 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
331 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
332 sock_->checkEor(appEor, 3600 + m);
333 verifyVec(buf, m, pos);
337 uint32_t countWritten = 0;
338 uint32_t partialWritten = 0;
339 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
340 &countWritten, &partialWritten);
341 EXPECT_EQ(countWritten, n);
342 EXPECT_EQ(partialWritten, 0);
343 sock_->checkEor(0, 0);
346 // WriteFlags::EOR set
348 // Partial write at 1000-th byte
349 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
351 auto vec = makeVec({1600});
353 static constexpr size_t initAppBytesWritten = 500;
354 static constexpr size_t appEor = initAppBytesWritten + 1600;
356 sock_->setAppBytesWritten(initAppBytesWritten);
357 sock_->setEorTracking(true);
359 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
360 // rawBytesWritten after the initAppBytesWritten
361 // + some random SSL overhead
362 .WillOnce(Return(2000))
363 // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
364 // + some random SSL overhead
365 .WillOnce(Return(3100));
366 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
367 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
368 sock_->checkEor(appEor, 2000 + m);
369 verifyVec(buf, m, pos);
373 uint32_t countWritten = 0;
374 uint32_t partialWritten = 0;
375 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
376 &countWritten, &partialWritten);
377 EXPECT_EQ(countWritten, 0);
378 EXPECT_EQ(partialWritten, 1000);
379 sock_->checkEor(appEor, 2000 + 1600);
380 consumeVec(vec.get(), countWritten, partialWritten);
382 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
383 .WillOnce(Return(3100))
384 .WillOnce(Return(3800));
385 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
386 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
387 sock_->checkEor(appEor, 3100 + m);
388 verifyVec(buf, m, pos);
391 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
393 &countWritten, &partialWritten);
394 EXPECT_EQ(countWritten, n);
395 EXPECT_EQ(partialWritten, 0);
396 sock_->checkEor(0, 0);