d47af702d428a17bc9e265630f17c267c343945b
[folly.git] / folly / io / async / test / AsyncSSLSocketWriteTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/Cursor.h>
17 #include <folly/io/async/AsyncSSLSocket.h>
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/GMock.h>
21 #include <folly/portability/GTest.h>
22
23 #include <string>
24 #include <vector>
25
26 using std::string;
27 using namespace testing;
28
29 namespace folly {
30
31 class MockAsyncSSLSocket : public AsyncSSLSocket{
32  public:
33   static std::shared_ptr<MockAsyncSSLSocket> newSocket(
34     const std::shared_ptr<SSLContext>& ctx,
35     EventBase* evb) {
36     auto sock = std::shared_ptr<MockAsyncSSLSocket>(
37       new MockAsyncSSLSocket(ctx, evb),
38       Destructor());
39     sock->ssl_ = SSL_new(ctx->getSSLCtx());
40     SSL_set_fd(sock->ssl_, -1);
41     return sock;
42   }
43
44   // Fake constructor sets the state to established without call to connect
45   // or accept
46   MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
47                       EventBase* evb)
48       : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
49     state_ = AsyncSocket::StateEnum::ESTABLISHED;
50     sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
51   }
52
53   // mock the calls to SSL_write to see the buffer length and contents
54   MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
55
56   // mock the calls to getRawBytesWritten()
57   MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
58
59   // public wrapper for protected interface
60   WriteResult testPerformWrite(
61       const iovec* vec,
62       uint32_t count,
63       WriteFlags flags,
64       uint32_t* countWritten,
65       uint32_t* partialWritten) {
66     return performWrite(vec, count, flags, countWritten, partialWritten);
67   }
68
69   void checkEor(size_t appEor, size_t rawEor) {
70     EXPECT_EQ(appEor, appEorByteNo_);
71     EXPECT_EQ(rawEor, minEorRawByteNo_);
72   }
73
74   void setAppBytesWritten(size_t n) {
75     appBytesWritten_ = n;
76   }
77 };
78
79 class AsyncSSLSocketWriteTest : public testing::Test {
80  public:
81   AsyncSSLSocketWriteTest() :
82       sslContext_(new SSLContext()),
83       sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
84     for (int i = 0; i < 500; i++) {
85       memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
86     }
87   }
88
89   // Make an iovec containing chunks of the reference text with requested sizes
90   // for each chunk
91   std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
92     std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
93     int i = 0;
94     int pos = 0;
95     for (auto size: sizes) {
96       vec[i].iov_base = (void *)(source_ + pos);
97       vec[i++].iov_len = size;
98       pos += size;
99     }
100     return vec;
101   }
102
103   // Verify that the given buf/pos matches the reference text
104   void verifyVec(const void *buf, int n, int pos) {
105     ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
106   }
107
108   // Update a vec on partial write
109   void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
110     vec[countWritten].iov_base =
111       ((char *)vec[countWritten].iov_base) + partialWritten;
112     vec[countWritten].iov_len -= partialWritten;
113   }
114
115   EventBase eventBase_;
116   std::shared_ptr<SSLContext> sslContext_;
117   std::shared_ptr<MockAsyncSSLSocket> sock_;
118   char source_[26 * 500];
119 };
120
121
122 // The entire vec fits in one packet
123 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
124   int n = 3;
125   auto vec = makeVec({3, 3, 3});
126   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
127     .WillOnce(Invoke([this] (SSL *, const void *buf, int m) {
128           verifyVec(buf, m, 0);
129           return 9; }));
130   uint32_t countWritten = 0;
131   uint32_t partialWritten = 0;
132   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
133                           &partialWritten);
134   EXPECT_EQ(countWritten, n);
135   EXPECT_EQ(partialWritten, 0);
136 }
137
138 // First packet is full, second two go in one packet
139 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
140   int n = 3;
141   auto vec = makeVec({1500, 3, 3});
142   int pos = 0;
143   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
144     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
145           verifyVec(buf, m, pos);
146           pos += m;
147           return m; }));
148   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
149     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
150           verifyVec(buf, m, pos);
151           pos += m;
152           return m; }));
153   uint32_t countWritten = 0;
154   uint32_t partialWritten = 0;
155   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
156                           &partialWritten);
157   EXPECT_EQ(countWritten, n);
158   EXPECT_EQ(partialWritten, 0);
159 }
160
161 // Two exactly full packets (coalesce ends midway through second chunk)
162 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
163   int n = 3;
164   auto vec = makeVec({1000, 1000, 1000});
165   int pos = 0;
166   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
167     .Times(2)
168     .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int m) {
169           verifyVec(buf, m, pos);
170           pos += m;
171           return m; }));
172   uint32_t countWritten = 0;
173   uint32_t partialWritten = 0;
174   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
175                           &partialWritten);
176   EXPECT_EQ(countWritten, n);
177   EXPECT_EQ(partialWritten, 0);
178 }
179
180 // Partial write success midway through a coalesced vec
181 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
182   int n = 5;
183   auto vec = makeVec({300, 300, 300, 300, 300});
184   int pos = 0;
185   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
186     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
187           verifyVec(buf, m, pos);
188           pos += 1000;
189           return 1000; /* 500 bytes "pending" */ }));
190   uint32_t countWritten = 0;
191   uint32_t partialWritten = 0;
192   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
193                           &partialWritten);
194   EXPECT_EQ(countWritten, 3);
195   EXPECT_EQ(partialWritten, 100);
196   consumeVec(vec.get(), countWritten, partialWritten);
197   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
198     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
199           verifyVec(buf, m, pos);
200           pos += m;
201           return 500; }));
202   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
203                           WriteFlags::NONE,
204                           &countWritten, &partialWritten);
205   EXPECT_EQ(countWritten, 2);
206   EXPECT_EQ(partialWritten, 0);
207 }
208
209 // coalesce ends exactly on a buffer boundary
210 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
211   int n = 3;
212   auto vec = makeVec({1000, 500, 500});
213   int pos = 0;
214   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
215     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
216           verifyVec(buf, m, pos);
217           pos += m;
218           return m; }));
219   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
220     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
221           verifyVec(buf, m, pos);
222           pos += m;
223           return m; }));
224   uint32_t countWritten = 0;
225   uint32_t partialWritten = 0;
226   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
227                           &partialWritten);
228   EXPECT_EQ(countWritten, 3);
229   EXPECT_EQ(partialWritten, 0);
230 }
231
232 // partial write midway through first chunk
233 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
234   int n = 2;
235   auto vec = makeVec({1000, 500});
236   int pos = 0;
237   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
238     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
239           verifyVec(buf, m, pos);
240           pos += 700;
241           return 700; }));
242   uint32_t countWritten = 0;
243   uint32_t partialWritten = 0;
244   sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
245                           &partialWritten);
246   EXPECT_EQ(countWritten, 0);
247   EXPECT_EQ(partialWritten, 700);
248   consumeVec(vec.get(), countWritten, partialWritten);
249   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
250     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
251           verifyVec(buf, m, pos);
252           pos += m;
253           return m; }));
254   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
255                           WriteFlags::NONE,
256                           &countWritten, &partialWritten);
257   EXPECT_EQ(countWritten, 2);
258   EXPECT_EQ(partialWritten, 0);
259 }
260
261 // Repeat coalescing2 with WriteFlags::EOR
262 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
263   int n = 3;
264   auto vec = makeVec({1500, 3, 3});
265   int pos = 0;
266   const size_t initAppBytesWritten = 500;
267   const size_t appEor = initAppBytesWritten + 1506;
268
269   sock_->setAppBytesWritten(initAppBytesWritten);
270   EXPECT_FALSE(sock_->isEorTrackingEnabled());
271   sock_->setEorTracking(true);
272   EXPECT_TRUE(sock_->isEorTrackingEnabled());
273
274   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
275     // rawBytesWritten after writting initAppBytesWritten + 1500
276     // + some random SSL overhead
277     .WillOnce(Return(3600))
278     // rawBytesWritten after writting last 6 bytes
279     // + some random SSL overhead
280     .WillOnce(Return(3728));
281   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
282     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
283           // the first 1500 does not have the EOR byte
284           sock_->checkEor(0, 0);
285           verifyVec(buf, m, pos);
286           pos += m;
287           return m; }));
288   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
289     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
290           sock_->checkEor(appEor, 3600 + m);
291           verifyVec(buf, m, pos);
292           pos += m;
293           return m; }));
294
295   uint32_t countWritten = 0;
296   uint32_t partialWritten = 0;
297   sock_->testPerformWrite(vec.get(), n , WriteFlags::EOR,
298                           &countWritten, &partialWritten);
299   EXPECT_EQ(countWritten, n);
300   EXPECT_EQ(partialWritten, 0);
301   sock_->checkEor(0, 0);
302 }
303
304 // coalescing with left over at the last chunk
305 // WriteFlags::EOR turned on
306 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
307   int n = 3;
308   auto vec = makeVec({600, 600, 600});
309   int pos = 0;
310   const size_t initAppBytesWritten = 500;
311   const size_t appEor = initAppBytesWritten + 1800;
312
313   sock_->setAppBytesWritten(initAppBytesWritten);
314   sock_->setEorTracking(true);
315
316   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
317     // rawBytesWritten after writting initAppBytesWritten +  1500 bytes
318     // + some random SSL overhead
319     .WillOnce(Return(3600))
320     // rawBytesWritten after writting last 300 bytes
321     // + some random SSL overhead
322     .WillOnce(Return(4100));
323   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
324     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
325           // the first 1500 does not have the EOR byte
326           sock_->checkEor(0, 0);
327           verifyVec(buf, m, pos);
328           pos += m;
329           return m; }));
330   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
331     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int m) {
332           sock_->checkEor(appEor, 3600 + m);
333           verifyVec(buf, m, pos);
334           pos += m;
335           return m; }));
336
337   uint32_t countWritten = 0;
338   uint32_t partialWritten = 0;
339   sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
340                           &countWritten, &partialWritten);
341   EXPECT_EQ(countWritten, n);
342   EXPECT_EQ(partialWritten, 0);
343   sock_->checkEor(0, 0);
344 }
345
346 // WriteFlags::EOR set
347 // One buf in iovec
348 // Partial write at 1000-th byte
349 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
350   int n = 1;
351   auto vec = makeVec({1600});
352   int pos = 0;
353   static constexpr size_t initAppBytesWritten = 500;
354   static constexpr size_t appEor = initAppBytesWritten + 1600;
355
356   sock_->setAppBytesWritten(initAppBytesWritten);
357   sock_->setEorTracking(true);
358
359   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
360     // rawBytesWritten after the initAppBytesWritten
361     // + some random SSL overhead
362     .WillOnce(Return(2000))
363     // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
364     // + some random SSL overhead
365     .WillOnce(Return(3100));
366   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
367     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
368           sock_->checkEor(appEor, 2000 + m);
369           verifyVec(buf, m, pos);
370           pos += 1000;
371           return 1000; }));
372
373   uint32_t countWritten = 0;
374   uint32_t partialWritten = 0;
375   sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
376                           &countWritten, &partialWritten);
377   EXPECT_EQ(countWritten, 0);
378   EXPECT_EQ(partialWritten, 1000);
379   sock_->checkEor(appEor, 2000 + 1600);
380   consumeVec(vec.get(), countWritten, partialWritten);
381
382   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
383     .WillOnce(Return(3100))
384     .WillOnce(Return(3800));
385   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
386     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int m) {
387           sock_->checkEor(appEor, 3100 + m);
388           verifyVec(buf, m, pos);
389           pos += m;
390           return m; }));
391   sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
392                           WriteFlags::EOR,
393                           &countWritten, &partialWritten);
394   EXPECT_EQ(countWritten, n);
395   EXPECT_EQ(partialWritten, 0);
396   sock_->checkEor(0, 0);
397 }
398
399 }