8b9bf95fb52327ffc6488e1bfd23df9bdf18db59
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t nBytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     LOG(ERROR) << ex.what();
76     state = STATE_FAILED;
77     this->bytesWritten = nBytesWritten;
78     exception = ex;
79     if (errorCallback) {
80       errorCallback();
81     }
82   }
83
84   StateEnum state;
85   size_t bytesWritten;
86   folly::AsyncSocketException exception;
87   VoidCallback successCallback;
88   VoidCallback errorCallback;
89 };
90
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
92  public:
93   explicit ReadCallback(size_t _maxBufferSz = 4096)
94       : state(STATE_WAITING),
95         exception(folly::AsyncSocketException::UNKNOWN, "none"),
96         buffers(),
97         maxBufferSz(_maxBufferSz) {}
98
99   ~ReadCallback() {
100     for (std::vector<Buffer>::iterator it = buffers.begin();
101          it != buffers.end();
102          ++it) {
103       it->free();
104     }
105     currentBuffer.free();
106   }
107
108   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109     if (!currentBuffer.buffer) {
110       currentBuffer.allocate(maxBufferSz);
111     }
112     *bufReturn = currentBuffer.buffer;
113     *lenReturn = currentBuffer.length;
114   }
115
116   void readDataAvailable(size_t len) noexcept override {
117     currentBuffer.length = len;
118     buffers.push_back(currentBuffer);
119     currentBuffer.reset();
120     if (dataAvailableCallback) {
121       dataAvailableCallback();
122     }
123   }
124
125   void readEOF() noexcept override {
126     state = STATE_SUCCEEDED;
127   }
128
129   void readErr(const folly::AsyncSocketException& ex) noexcept override {
130     state = STATE_FAILED;
131     exception = ex;
132   }
133
134   void verifyData(const char* expected, size_t expectedLen) const {
135     size_t offset = 0;
136     for (size_t idx = 0; idx < buffers.size(); ++idx) {
137       const auto& buf = buffers[idx];
138       size_t cmpLen = std::min(buf.length, expectedLen - offset);
139       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140       CHECK_EQ(cmpLen, buf.length);
141       offset += cmpLen;
142     }
143     CHECK_EQ(offset, expectedLen);
144   }
145
146   size_t dataRead() const {
147     size_t ret = 0;
148     for (const auto& buf : buffers) {
149       ret += buf.length;
150     }
151     return ret;
152   }
153
154   class Buffer {
155    public:
156     Buffer() : buffer(nullptr), length(0) {}
157     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
158
159     void reset() {
160       buffer = nullptr;
161       length = 0;
162     }
163     void allocate(size_t len) {
164       assert(buffer == nullptr);
165       this->buffer = static_cast<char*>(malloc(len));
166       this->length = len;
167     }
168     void free() {
169       ::free(buffer);
170       reset();
171     }
172
173     char* buffer;
174     size_t length;
175   };
176
177   StateEnum state;
178   folly::AsyncSocketException exception;
179   std::vector<Buffer> buffers;
180   Buffer currentBuffer;
181   VoidCallback dataAvailableCallback;
182   const size_t maxBufferSz;
183 };
184
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
186  public:
187   BufferCallback() : buffered_(false), bufferCleared_(false) {}
188
189   void onEgressBuffered() override { buffered_ = true; }
190
191   void onEgressBufferCleared() override { bufferCleared_ = true; }
192
193   bool hasBuffered() const { return buffered_; }
194
195   bool hasBufferCleared() const { return bufferCleared_; }
196
197  private:
198   bool buffered_{false};
199   bool bufferCleared_{false};
200 };
201
202 class ReadVerifier {
203 };
204
205 class TestServer {
206  public:
207   // Create a TestServer.
208   // This immediately starts listening on an ephemeral port.
209   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
210     namespace fsp = folly::portability::sockets;
211     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
212     if (fd_ < 0) {
213       throw folly::AsyncSocketException(
214           folly::AsyncSocketException::INTERNAL_ERROR,
215           "failed to create test server socket",
216           errno);
217     }
218     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
219       throw folly::AsyncSocketException(
220           folly::AsyncSocketException::INTERNAL_ERROR,
221           "failed to put test server socket in "
222           "non-blocking mode",
223           errno);
224     }
225     if (enableTFO) {
226 #if FOLLY_ALLOW_TFO
227       folly::detail::tfo_enable(fd_, 100);
228 #endif
229     }
230
231     struct addrinfo hints, *res;
232     memset(&hints, 0, sizeof(hints));
233     hints.ai_family = AF_INET;
234     hints.ai_socktype = SOCK_STREAM;
235     hints.ai_flags = AI_PASSIVE;
236
237     if (getaddrinfo(nullptr, "0", &hints, &res)) {
238       throw folly::AsyncSocketException(
239           folly::AsyncSocketException::INTERNAL_ERROR,
240           "Attempted to bind address to socket with "
241           "bad getaddrinfo",
242           errno);
243     }
244
245     SCOPE_EXIT {
246       freeaddrinfo(res);
247     };
248
249     if (bufSize > 0) {
250       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
251       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
252     }
253
254     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
255       throw folly::AsyncSocketException(
256           folly::AsyncSocketException::INTERNAL_ERROR,
257           "failed to bind to async server socket for port 10",
258           errno);
259     }
260
261     if (listen(fd_, 10) != 0) {
262       throw folly::AsyncSocketException(
263           folly::AsyncSocketException::INTERNAL_ERROR,
264           "failed to listen on test server socket",
265           errno);
266     }
267
268     address_.setFromLocalAddress(fd_);
269     // The local address will contain 0.0.0.0.
270     // Change it to 127.0.0.1, so it can be used to connect to the server
271     address_.setFromIpPort("127.0.0.1", address_.getPort());
272   }
273
274   ~TestServer() {
275     if (fd_ != -1) {
276       close(fd_);
277     }
278   }
279
280   // Get the address for connecting to the server
281   const folly::SocketAddress& getAddress() const {
282     return address_;
283   }
284
285   int acceptFD(int timeout=50) {
286     namespace fsp = folly::portability::sockets;
287     struct pollfd pfd;
288     pfd.fd = fd_;
289     pfd.events = POLLIN;
290     int ret = poll(&pfd, 1, timeout);
291     if (ret == 0) {
292       throw folly::AsyncSocketException(
293           folly::AsyncSocketException::INTERNAL_ERROR,
294           "test server accept() timed out");
295     } else if (ret < 0) {
296       throw folly::AsyncSocketException(
297           folly::AsyncSocketException::INTERNAL_ERROR,
298           "test server accept() poll failed",
299           errno);
300     }
301
302     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
303     if (acceptedFd < 0) {
304       throw folly::AsyncSocketException(
305           folly::AsyncSocketException::INTERNAL_ERROR,
306           "test server accept() failed",
307           errno);
308     }
309
310     return acceptedFd;
311   }
312
313   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
314     int fd = acceptFD(timeout);
315     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
316   }
317
318   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
319                                                   int timeout = 50) {
320     int fd = acceptFD(timeout);
321     return folly::AsyncSocket::newSocket(evb, fd);
322   }
323
324   /**
325    * Accept a connection, read data from it, and verify that it matches the
326    * data in the specified buffer.
327    */
328   void verifyConnection(const char* buf, size_t len) {
329     // accept a connection
330     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
331     // read the data and compare it to the specified buffer
332     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
333     acceptedSocket->readAll(readbuf.get(), len);
334     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
335     // make sure we get EOF next
336     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
337     CHECK_EQ(bytesRead, 0);
338   }
339
340  private:
341   int fd_;
342   folly::SocketAddress address_;
343 };