13270a69ab44533b81ed4eec749f912613c3bd5b
[folly.git] / folly / io / async / test / AsyncSSLSocketWriteTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/Foreach.h>
17 #include <folly/io/Cursor.h>
18 #include <folly/io/async/AsyncSSLSocket.h>
19 #include <folly/io/async/AsyncSocket.h>
20 #include <folly/io/async/EventBase.h>
21 #include <folly/portability/GMock.h>
22
23 #include <gtest/gtest.h>
24 #include <string>
25 #include <vector>
26
27 using std::string;
28 using namespace testing;
29
30 namespace folly {
31
32 class MockAsyncSSLSocket : public AsyncSSLSocket{
33  public:
34   static std::shared_ptr<MockAsyncSSLSocket> newSocket(
35     const std::shared_ptr<SSLContext>& ctx,
36     EventBase* evb) {
37     auto sock = std::shared_ptr<MockAsyncSSLSocket>(
38       new MockAsyncSSLSocket(ctx, evb),
39       Destructor());
40     sock->ssl_ = SSL_new(ctx->getSSLCtx());
41     SSL_set_fd(sock->ssl_, -1);
42     return sock;
43   }
44
45   // Fake constructor sets the state to established without call to connect
46   // or accept
47   MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
48                       EventBase* evb)
49       : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
50     state_ = AsyncSocket::StateEnum::ESTABLISHED;
51     sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
52   }
53
54   // mock the calls to SSL_write to see the buffer length and contents
55   MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
56
57   // mock the calls to getRawBytesWritten()
58   MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
59
60   // public wrapper for protected interface
61   WriteResult testPerformWrite(
62       const iovec* vec,
63       uint32_t count,
64       WriteFlags flags,
65       uint32_t* countWritten,
66       uint32_t* partialWritten) {
67     return performWrite(vec, count, flags, countWritten, partialWritten);
68   }
69
70   void checkEor(size_t appEor, size_t rawEor) {
71     EXPECT_EQ(appEor, appEorByteNo_);
72     EXPECT_EQ(rawEor, minEorRawByteNo_);
73   }
74
75   void setAppBytesWritten(size_t n) {
76     appBytesWritten_ = n;
77   }
78 };
79
80 class AsyncSSLSocketWriteTest : public testing::Test {
81  public:
82   AsyncSSLSocketWriteTest() :
83       sslContext_(new SSLContext()),
84       sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
85     for (int i = 0; i < 500; i++) {
86       memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
87     }
88   }
89
90   // Make an iovec containing chunks of the reference text with requested sizes
91   // for each chunk
92   std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
93     std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
94     int i = 0;
95     int pos = 0;
96     for (auto size: sizes) {
97       vec[i].iov_base = (void *)(source_ + pos);
98       vec[i++].iov_len = size;
99       pos += size;
100     }
101     return vec;
102   }
103
104   // Verify that the given buf/pos matches the reference text
105   void verifyVec(const void *buf, int n, int pos) {
106     ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
107   }
108
109   // Update a vec on partial write
110   void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
111     vec[countWritten].iov_base =
112       ((char *)vec[countWritten].iov_base) + partialWritten;
113     vec[countWritten].iov_len -= partialWritten;
114   }
115
116   EventBase eventBase_;
117   std::shared_ptr<SSLContext> sslContext_;
118   std::shared_ptr<MockAsyncSSLSocket> sock_;
119   char source_[26 * 500];
120 };
121
122
123 // The entire vec fits in one packet
124 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
125   int n = 3;
126   auto vec = makeVec({3, 3, 3});
127   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
128     .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
129           verifyVec(buf, n, 0);
130           return 9; }));
131   uint32_t countWritten = 0;
132   uint32_t partialWritten = 0;
133   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
134                           &partialWritten);
135   EXPECT_EQ(countWritten, n);
136   EXPECT_EQ(partialWritten, 0);
137 }
138
139 // First packet is full, second two go in one packet
140 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
141   int n = 3;
142   auto vec = makeVec({1500, 3, 3});
143   int pos = 0;
144   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
145     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
146           verifyVec(buf, n, pos);
147           pos += n;
148           return n; }));
149   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
150     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
151           verifyVec(buf, n, pos);
152           pos += n;
153           return n; }));
154   uint32_t countWritten = 0;
155   uint32_t partialWritten = 0;
156   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
157                           &partialWritten);
158   EXPECT_EQ(countWritten, n);
159   EXPECT_EQ(partialWritten, 0);
160 }
161
162 // Two exactly full packets (coalesce ends midway through second chunk)
163 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
164   int n = 3;
165   auto vec = makeVec({1000, 1000, 1000});
166   int pos = 0;
167   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
168     .Times(2)
169     .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
170           verifyVec(buf, n, pos);
171           pos += n;
172           return n; }));
173   uint32_t countWritten = 0;
174   uint32_t partialWritten = 0;
175   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
176                           &partialWritten);
177   EXPECT_EQ(countWritten, n);
178   EXPECT_EQ(partialWritten, 0);
179 }
180
181 // Partial write success midway through a coalesced vec
182 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
183   int n = 5;
184   auto vec = makeVec({300, 300, 300, 300, 300});
185   int pos = 0;
186   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
187     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
188           verifyVec(buf, n, pos);
189           pos += 1000;
190           return 1000; /* 500 bytes "pending" */ }));
191   uint32_t countWritten = 0;
192   uint32_t partialWritten = 0;
193   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
194                           &partialWritten);
195   EXPECT_EQ(countWritten, 3);
196   EXPECT_EQ(partialWritten, 100);
197   consumeVec(vec.get(), countWritten, partialWritten);
198   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
199     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
200           verifyVec(buf, n, pos);
201           pos += n;
202           return 500; }));
203   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
204                           WriteFlags::NONE,
205                           &countWritten, &partialWritten);
206   EXPECT_EQ(countWritten, 2);
207   EXPECT_EQ(partialWritten, 0);
208 }
209
210 // coalesce ends exactly on a buffer boundary
211 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
212   int n = 3;
213   auto vec = makeVec({1000, 500, 500});
214   int pos = 0;
215   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
216     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
217           verifyVec(buf, n, pos);
218           pos += n;
219           return n; }));
220   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
221     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
222           verifyVec(buf, n, pos);
223           pos += n;
224           return n; }));
225   uint32_t countWritten = 0;
226   uint32_t partialWritten = 0;
227   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
228                           &partialWritten);
229   EXPECT_EQ(countWritten, 3);
230   EXPECT_EQ(partialWritten, 0);
231 }
232
233 // partial write midway through first chunk
234 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
235   int n = 2;
236   auto vec = makeVec({1000, 500});
237   int pos = 0;
238   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
239     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
240           verifyVec(buf, n, pos);
241           pos += 700;
242           return 700; }));
243   uint32_t countWritten = 0;
244   uint32_t partialWritten = 0;
245   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
246                           &partialWritten);
247   EXPECT_EQ(countWritten, 0);
248   EXPECT_EQ(partialWritten, 700);
249   consumeVec(vec.get(), countWritten, partialWritten);
250   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
251     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
252           verifyVec(buf, n, pos);
253           pos += n;
254           return n; }));
255   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
256                           WriteFlags::NONE,
257                           &countWritten, &partialWritten);
258   EXPECT_EQ(countWritten, 2);
259   EXPECT_EQ(partialWritten, 0);
260 }
261
262 // Repeat coalescing2 with WriteFlags::EOR
263 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
264   int n = 3;
265   auto vec = makeVec({1500, 3, 3});
266   int pos = 0;
267   const size_t initAppBytesWritten = 500;
268   const size_t appEor = initAppBytesWritten + 1506;
269
270   sock_->setAppBytesWritten(initAppBytesWritten);
271   EXPECT_FALSE(sock_->isEorTrackingEnabled());
272   sock_->setEorTracking(true);
273   EXPECT_TRUE(sock_->isEorTrackingEnabled());
274
275   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
276     // rawBytesWritten after writting initAppBytesWritten + 1500
277     // + some random SSL overhead
278     .WillOnce(Return(3600))
279     // rawBytesWritten after writting last 6 bytes
280     // + some random SSL overhead
281     .WillOnce(Return(3728));
282   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
283     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
284           // the first 1500 does not have the EOR byte
285           sock_->checkEor(0, 0);
286           verifyVec(buf, n, pos);
287           pos += n;
288           return n; }));
289   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
290     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
291           sock_->checkEor(appEor, 3600 + n);
292           verifyVec(buf, n, pos);
293           pos += n;
294           return n; }));
295
296   uint32_t countWritten = 0;
297   uint32_t partialWritten = 0;
298   sock_->testPerformWrite(vec.get(), n , WriteFlags::EOR,
299                           &countWritten, &partialWritten);
300   EXPECT_EQ(countWritten, n);
301   EXPECT_EQ(partialWritten, 0);
302   sock_->checkEor(0, 0);
303 }
304
305 // coalescing with left over at the last chunk
306 // WriteFlags::EOR turned on
307 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
308   int n = 3;
309   auto vec = makeVec({600, 600, 600});
310   int pos = 0;
311   const size_t initAppBytesWritten = 500;
312   const size_t appEor = initAppBytesWritten + 1800;
313
314   sock_->setAppBytesWritten(initAppBytesWritten);
315   sock_->setEorTracking(true);
316
317   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
318     // rawBytesWritten after writting initAppBytesWritten +  1500 bytes
319     // + some random SSL overhead
320     .WillOnce(Return(3600))
321     // rawBytesWritten after writting last 300 bytes
322     // + some random SSL overhead
323     .WillOnce(Return(4100));
324   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
325     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
326           // the first 1500 does not have the EOR byte
327           sock_->checkEor(0, 0);
328           verifyVec(buf, n, pos);
329           pos += n;
330           return n; }));
331   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
332     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
333           sock_->checkEor(appEor, 3600 + n);
334           verifyVec(buf, n, pos);
335           pos += n;
336           return n; }));
337
338   uint32_t countWritten = 0;
339   uint32_t partialWritten = 0;
340   sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
341                           &countWritten, &partialWritten);
342   EXPECT_EQ(countWritten, n);
343   EXPECT_EQ(partialWritten, 0);
344   sock_->checkEor(0, 0);
345 }
346
347 // WriteFlags::EOR set
348 // One buf in iovec
349 // Partial write at 1000-th byte
350 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
351   int n = 1;
352   auto vec = makeVec({1600});
353   int pos = 0;
354   static constexpr size_t initAppBytesWritten = 500;
355   static constexpr size_t appEor = initAppBytesWritten + 1600;
356
357   sock_->setAppBytesWritten(initAppBytesWritten);
358   sock_->setEorTracking(true);
359
360   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
361     // rawBytesWritten after the initAppBytesWritten
362     // + some random SSL overhead
363     .WillOnce(Return(2000))
364     // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
365     // + some random SSL overhead
366     .WillOnce(Return(3100));
367   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
368     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
369           sock_->checkEor(appEor, 2000 + n);
370           verifyVec(buf, n, pos);
371           pos += 1000;
372           return 1000; }));
373
374   uint32_t countWritten = 0;
375   uint32_t partialWritten = 0;
376   sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
377                           &countWritten, &partialWritten);
378   EXPECT_EQ(countWritten, 0);
379   EXPECT_EQ(partialWritten, 1000);
380   sock_->checkEor(appEor, 2000 + 1600);
381   consumeVec(vec.get(), countWritten, partialWritten);
382
383   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
384     .WillOnce(Return(3100))
385     .WillOnce(Return(3800));
386   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
387     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
388           sock_->checkEor(appEor, 3100 + n);
389           verifyVec(buf, n, pos);
390           pos += n;
391           return n; }));
392   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
393                           WriteFlags::EOR,
394                           &countWritten, &partialWritten);
395   EXPECT_EQ(countWritten, n);
396   EXPECT_EQ(partialWritten, 0);
397   sock_->checkEor(0, 0);
398 }
399
400 }