Move AsyncSocket tests from thrift to folly
[folly.git] / folly / io / async / test / AsyncSSLSocketWriteTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/Foreach.h>
17 #include <folly/io/Cursor.h>
18 #include <folly/io/async/AsyncSSLSocket.h>
19 #include <folly/io/async/AsyncSocket.h>
20 #include <folly/io/async/EventBase.h>
21
22 #include <gtest/gtest.h>
23 #include <gmock/gmock.h>
24 #include <string>
25 #include <vector>
26
27 using std::string;
28 using namespace testing;
29
30 namespace folly {
31
32 class MockAsyncSSLSocket : public AsyncSSLSocket{
33  public:
34   static std::shared_ptr<MockAsyncSSLSocket> newSocket(
35     const std::shared_ptr<SSLContext>& ctx,
36     EventBase* evb) {
37     auto sock = std::shared_ptr<MockAsyncSSLSocket>(
38       new MockAsyncSSLSocket(ctx, evb),
39       Destructor());
40     sock->ssl_ = SSL_new(ctx->getSSLCtx());
41     SSL_set_fd(sock->ssl_, -1);
42     return sock;
43   }
44
45   // Fake constructor sets the state to established without call to connect
46   // or accept
47   MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
48                       EventBase* evb)
49       : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
50     state_ = AsyncSocket::StateEnum::ESTABLISHED;
51     sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
52   }
53
54   // mock the calls to SSL_write to see the buffer length and contents
55   MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
56
57   // mock the calls to getRawBytesWritten()
58   MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
59
60   // public wrapper for protected interface
61   ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
62                            uint32_t* countWritten, uint32_t* partialWritten) {
63     return performWrite(vec, count, flags, countWritten, partialWritten);
64   }
65
66   void checkEor(size_t appEor, size_t rawEor) {
67     EXPECT_EQ(appEor, appEorByteNo_);
68     EXPECT_EQ(rawEor, minEorRawByteNo_);
69   }
70
71   void setAppBytesWritten(size_t n) {
72     appBytesWritten_ = n;
73   }
74 };
75
76 class AsyncSSLSocketWriteTest : public testing::Test {
77  public:
78   AsyncSSLSocketWriteTest() :
79       sslContext_(new SSLContext()),
80       sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
81     for (int i = 0; i < 500; i++) {
82       memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
83     }
84   }
85
86   // Make an iovec containing chunks of the reference text with requested sizes
87   // for each chunk
88   iovec *makeVec(std::vector<uint32_t> sizes) {
89     iovec *vec = new iovec[sizes.size()];
90     int i = 0;
91     int pos = 0;
92     for (auto size: sizes) {
93       vec[i].iov_base = (void *)(source_ + pos);
94       vec[i++].iov_len = size;
95       pos += size;
96     }
97     return vec;
98   }
99
100   // Verify that the given buf/pos matches the reference text
101   void verifyVec(const void *buf, int n, int pos) {
102     ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
103   }
104
105   // Update a vec on partial write
106   void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
107     vec[countWritten].iov_base =
108       ((char *)vec[countWritten].iov_base) + partialWritten;
109     vec[countWritten].iov_len -= partialWritten;
110   }
111
112   EventBase eventBase_;
113   std::shared_ptr<SSLContext> sslContext_;
114   std::shared_ptr<MockAsyncSSLSocket> sock_;
115   char source_[26 * 500];
116 };
117
118
119 // The entire vec fits in one packet
120 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
121   int n = 3;
122   iovec *vec = makeVec({3, 3, 3});
123   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
124     .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
125           verifyVec(buf, n, 0);
126           return 9; }));
127   uint32_t countWritten = 0;
128   uint32_t partialWritten = 0;
129   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
130                           &partialWritten);
131   EXPECT_EQ(countWritten, n);
132   EXPECT_EQ(partialWritten, 0);
133 }
134
135 // First packet is full, second two go in one packet
136 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
137   int n = 3;
138   iovec *vec = makeVec({1500, 3, 3});
139   int pos = 0;
140   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
141     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
142           verifyVec(buf, n, pos);
143           pos += n;
144           return n; }));
145   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
146     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
147           verifyVec(buf, n, pos);
148           pos += n;
149           return n; }));
150   uint32_t countWritten = 0;
151   uint32_t partialWritten = 0;
152   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
153                           &partialWritten);
154   EXPECT_EQ(countWritten, n);
155   EXPECT_EQ(partialWritten, 0);
156 }
157
158 // Two exactly full packets (coalesce ends midway through second chunk)
159 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
160   int n = 3;
161   iovec *vec = makeVec({1000, 1000, 1000});
162   int pos = 0;
163   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
164     .Times(2)
165     .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
166           verifyVec(buf, n, pos);
167           pos += n;
168           return n; }));
169   uint32_t countWritten = 0;
170   uint32_t partialWritten = 0;
171   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
172                           &partialWritten);
173   EXPECT_EQ(countWritten, n);
174   EXPECT_EQ(partialWritten, 0);
175 }
176
177 // Partial write success midway through a coalesced vec
178 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
179   int n = 5;
180   iovec *vec = makeVec({300, 300, 300, 300, 300});
181   int pos = 0;
182   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
183     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
184           verifyVec(buf, n, pos);
185           pos += 1000;
186           return 1000; /* 500 bytes "pending" */ }));
187   uint32_t countWritten = 0;
188   uint32_t partialWritten = 0;
189   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
190                           &partialWritten);
191   EXPECT_EQ(countWritten, 3);
192   EXPECT_EQ(partialWritten, 100);
193   consumeVec(vec, countWritten, partialWritten);
194   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
195     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
196           verifyVec(buf, n, pos);
197           pos += n;
198           return 500; }));
199   sock_->testPerformWrite(vec + countWritten, n - countWritten,
200                           WriteFlags::NONE,
201                           &countWritten, &partialWritten);
202   EXPECT_EQ(countWritten, 2);
203   EXPECT_EQ(partialWritten, 0);
204 }
205
206 // coalesce ends exactly on a buffer boundary
207 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
208   int n = 3;
209   iovec *vec = makeVec({1000, 500, 500});
210   int pos = 0;
211   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
212     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
213           verifyVec(buf, n, pos);
214           pos += n;
215           return n; }));
216   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
217     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
218           verifyVec(buf, n, pos);
219           pos += n;
220           return n; }));
221   uint32_t countWritten = 0;
222   uint32_t partialWritten = 0;
223   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
224                           &partialWritten);
225   EXPECT_EQ(countWritten, 3);
226   EXPECT_EQ(partialWritten, 0);
227 }
228
229 // partial write midway through first chunk
230 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
231   int n = 2;
232   iovec *vec = makeVec({1000, 500});
233   int pos = 0;
234   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
235     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
236           verifyVec(buf, n, pos);
237           pos += 700;
238           return 700; }));
239   uint32_t countWritten = 0;
240   uint32_t partialWritten = 0;
241   sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
242                           &partialWritten);
243   EXPECT_EQ(countWritten, 0);
244   EXPECT_EQ(partialWritten, 700);
245   consumeVec(vec, countWritten, partialWritten);
246   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
247     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
248           verifyVec(buf, n, pos);
249           pos += n;
250           return n; }));
251   sock_->testPerformWrite(vec + countWritten, n - countWritten,
252                           WriteFlags::NONE,
253                           &countWritten, &partialWritten);
254   EXPECT_EQ(countWritten, 2);
255   EXPECT_EQ(partialWritten, 0);
256 }
257
258 // Repeat coalescing2 with WriteFlags::EOR
259 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
260   int n = 3;
261   iovec *vec = makeVec({1500, 3, 3});
262   int pos = 0;
263   const size_t initAppBytesWritten = 500;
264   const size_t appEor = initAppBytesWritten + 1506;
265
266   sock_->setAppBytesWritten(initAppBytesWritten);
267   EXPECT_FALSE(sock_->isEorTrackingEnabled());
268   sock_->setEorTracking(true);
269   EXPECT_TRUE(sock_->isEorTrackingEnabled());
270
271   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
272     // rawBytesWritten after writting initAppBytesWritten + 1500
273     // + some random SSL overhead
274     .WillOnce(Return(3600))
275     // rawBytesWritten after writting last 6 bytes
276     // + some random SSL overhead
277     .WillOnce(Return(3728));
278   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
279     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
280           // the first 1500 does not have the EOR byte
281           sock_->checkEor(0, 0);
282           verifyVec(buf, n, pos);
283           pos += n;
284           return n; }));
285   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
286     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
287           sock_->checkEor(appEor, 3600 + n);
288           verifyVec(buf, n, pos);
289           pos += n;
290           return n; }));
291
292   uint32_t countWritten = 0;
293   uint32_t partialWritten = 0;
294   sock_->testPerformWrite(vec, n , WriteFlags::EOR,
295                           &countWritten, &partialWritten);
296   EXPECT_EQ(countWritten, n);
297   EXPECT_EQ(partialWritten, 0);
298   sock_->checkEor(0, 0);
299 }
300
301 // coalescing with left over at the last chunk
302 // WriteFlags::EOR turned on
303 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
304   int n = 3;
305   iovec *vec = makeVec({600, 600, 600});
306   int pos = 0;
307   const size_t initAppBytesWritten = 500;
308   const size_t appEor = initAppBytesWritten + 1800;
309
310   sock_->setAppBytesWritten(initAppBytesWritten);
311   sock_->setEorTracking(true);
312
313   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
314     // rawBytesWritten after writting initAppBytesWritten +  1500 bytes
315     // + some random SSL overhead
316     .WillOnce(Return(3600))
317     // rawBytesWritten after writting last 300 bytes
318     // + some random SSL overhead
319     .WillOnce(Return(4100));
320   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
321     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
322           // the first 1500 does not have the EOR byte
323           sock_->checkEor(0, 0);
324           verifyVec(buf, n, pos);
325           pos += n;
326           return n; }));
327   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
328     .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
329           sock_->checkEor(appEor, 3600 + n);
330           verifyVec(buf, n, pos);
331           pos += n;
332           return n; }));
333
334   uint32_t countWritten = 0;
335   uint32_t partialWritten = 0;
336   sock_->testPerformWrite(vec, n, WriteFlags::EOR,
337                           &countWritten, &partialWritten);
338   EXPECT_EQ(countWritten, n);
339   EXPECT_EQ(partialWritten, 0);
340   sock_->checkEor(0, 0);
341 }
342
343 // WriteFlags::EOR set
344 // One buf in iovec
345 // Partial write at 1000-th byte
346 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
347   int n = 1;
348   iovec *vec = makeVec({1600});
349   int pos = 0;
350   const size_t initAppBytesWritten = 500;
351   const size_t appEor = initAppBytesWritten + 1600;
352
353   sock_->setAppBytesWritten(initAppBytesWritten);
354   sock_->setEorTracking(true);
355
356   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
357     // rawBytesWritten after the initAppBytesWritten
358     // + some random SSL overhead
359     .WillOnce(Return(2000))
360     // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
361     // + some random SSL overhead
362     .WillOnce(Return(3100));
363   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
364     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
365           sock_->checkEor(appEor, 2000 + n);
366           verifyVec(buf, n, pos);
367           pos += 1000;
368           return 1000; }));
369
370   uint32_t countWritten = 0;
371   uint32_t partialWritten = 0;
372   sock_->testPerformWrite(vec, n, WriteFlags::EOR,
373                           &countWritten, &partialWritten);
374   EXPECT_EQ(countWritten, 0);
375   EXPECT_EQ(partialWritten, 1000);
376   sock_->checkEor(appEor, 2000 + 1600);
377   consumeVec(vec, countWritten, partialWritten);
378
379   EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
380     .WillOnce(Return(3100))
381     .WillOnce(Return(3800));
382   EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
383     .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
384           sock_->checkEor(appEor, 3100 + n);
385           verifyVec(buf, n, pos);
386           pos += n;
387           return n; }));
388   sock_->testPerformWrite(vec + countWritten, n - countWritten,
389                           WriteFlags::EOR,
390                           &countWritten, &partialWritten);
391   EXPECT_EQ(countWritten, n);
392   EXPECT_EQ(partialWritten, 0);
393   sock_->checkEor(0, 0);
394 }
395
396 }