b84251f787977a4dbd5ce5157a7e0f79e49eaf8b
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t nBytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     LOG(ERROR) << ex.what();
76     state = STATE_FAILED;
77     this->bytesWritten = nBytesWritten;
78     exception = ex;
79     if (errorCallback) {
80       errorCallback();
81     }
82   }
83
84   StateEnum state;
85   size_t bytesWritten;
86   folly::AsyncSocketException exception;
87   VoidCallback successCallback;
88   VoidCallback errorCallback;
89 };
90
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
92  public:
93   explicit ReadCallback(size_t _maxBufferSz = 4096)
94       : state(STATE_WAITING),
95         exception(folly::AsyncSocketException::UNKNOWN, "none"),
96         buffers(),
97         maxBufferSz(_maxBufferSz) {}
98
99   ~ReadCallback() override {
100     for (std::vector<Buffer>::iterator it = buffers.begin();
101          it != buffers.end();
102          ++it) {
103       it->free();
104     }
105     currentBuffer.free();
106   }
107
108   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109     if (!currentBuffer.buffer) {
110       currentBuffer.allocate(maxBufferSz);
111     }
112     *bufReturn = currentBuffer.buffer;
113     *lenReturn = currentBuffer.length;
114   }
115
116   void readDataAvailable(size_t len) noexcept override {
117     currentBuffer.length = len;
118     buffers.push_back(currentBuffer);
119     currentBuffer.reset();
120     if (dataAvailableCallback) {
121       dataAvailableCallback();
122     }
123   }
124
125   void readEOF() noexcept override {
126     state = STATE_SUCCEEDED;
127   }
128
129   void readErr(const folly::AsyncSocketException& ex) noexcept override {
130     state = STATE_FAILED;
131     exception = ex;
132   }
133
134   void verifyData(const char* expected, size_t expectedLen) const {
135     size_t offset = 0;
136     for (size_t idx = 0; idx < buffers.size(); ++idx) {
137       const auto& buf = buffers[idx];
138       size_t cmpLen = std::min(buf.length, expectedLen - offset);
139       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140       CHECK_EQ(cmpLen, buf.length);
141       offset += cmpLen;
142     }
143     CHECK_EQ(offset, expectedLen);
144   }
145
146   size_t dataRead() const {
147     size_t ret = 0;
148     for (const auto& buf : buffers) {
149       ret += buf.length;
150     }
151     return ret;
152   }
153
154   class Buffer {
155    public:
156     Buffer() : buffer(nullptr), length(0) {}
157     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
158
159     void reset() {
160       buffer = nullptr;
161       length = 0;
162     }
163     void allocate(size_t len) {
164       assert(buffer == nullptr);
165       this->buffer = static_cast<char*>(malloc(len));
166       this->length = len;
167     }
168     void free() {
169       ::free(buffer);
170       reset();
171     }
172
173     char* buffer;
174     size_t length;
175   };
176
177   StateEnum state;
178   folly::AsyncSocketException exception;
179   std::vector<Buffer> buffers;
180   Buffer currentBuffer;
181   VoidCallback dataAvailableCallback;
182   const size_t maxBufferSz;
183 };
184
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
186  public:
187   BufferCallback() : buffered_(false), bufferCleared_(false) {}
188
189   void onEgressBuffered() override { buffered_ = true; }
190
191   void onEgressBufferCleared() override { bufferCleared_ = true; }
192
193   bool hasBuffered() const { return buffered_; }
194
195   bool hasBufferCleared() const { return bufferCleared_; }
196
197  private:
198   bool buffered_{false};
199   bool bufferCleared_{false};
200 };
201
202 class ReadVerifier {
203 };
204
205 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
206  public:
207   TestErrMessageCallback()
208     : exception_(folly::AsyncSocketException::UNKNOWN, "none")
209   {}
210
211   void errMessage(const cmsghdr& cmsg) noexcept override {
212     if (cmsg.cmsg_level == SOL_SOCKET &&
213       cmsg.cmsg_type == SCM_TIMESTAMPING) {
214       gotTimestamp_++;
215       checkResetCallback();
216     } else if (
217       (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
218       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
219       gotByteSeq_++;
220       checkResetCallback();
221     }
222   }
223
224   void errMessageError(
225       const folly::AsyncSocketException& ex) noexcept override {
226     exception_ = ex;
227   }
228
229   void checkResetCallback() noexcept {
230     if (socket_ != nullptr && resetAfter_ != -1 &&
231         gotTimestamp_ + gotByteSeq_ == resetAfter_) {
232       socket_->setErrMessageCB(nullptr);
233     }
234   }
235
236   folly::AsyncSocket* socket_{nullptr};
237   folly::AsyncSocketException exception_;
238   int gotTimestamp_{0};
239   int gotByteSeq_{0};
240   int resetAfter_{-1};
241 };
242
243 class TestSendMsgParamsCallback :
244     public folly::AsyncSocket::SendMsgParamsCallback {
245  public:
246   TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
247   : flags_(flags),
248     writeFlags_(folly::WriteFlags::NONE),
249     dataSize_(dataSize),
250     data_(data),
251     queriedFlags_(false),
252     queriedData_(false)
253   {}
254
255   void reset(int flags) {
256     flags_ = flags;
257     writeFlags_ = folly::WriteFlags::NONE;
258     queriedFlags_ = false;
259     queriedData_ = false;
260   }
261
262   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
263                                                                   override {
264     queriedFlags_ = true;
265     if (writeFlags_ == folly::WriteFlags::NONE) {
266       writeFlags_ = flags;
267     } else {
268       assert(flags == writeFlags_);
269     }
270     return flags_;
271   }
272
273   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
274     queriedData_ = true;
275     if (writeFlags_ == folly::WriteFlags::NONE) {
276       writeFlags_ = flags;
277     } else {
278       assert(flags == writeFlags_);
279     }
280     assert(data != nullptr);
281     memcpy(data, data_, dataSize_);
282   }
283
284   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
285     if (writeFlags_ == folly::WriteFlags::NONE) {
286       writeFlags_ = flags;
287     } else {
288       assert(flags == writeFlags_);
289     }
290     return dataSize_;
291   }
292
293   int flags_;
294   folly::WriteFlags writeFlags_;
295   uint32_t dataSize_;
296   void* data_;
297   bool queriedFlags_;
298   bool queriedData_;
299 };
300
301 class TestServer {
302  public:
303   // Create a TestServer.
304   // This immediately starts listening on an ephemeral port.
305   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
306     namespace fsp = folly::portability::sockets;
307     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
308     if (fd_ < 0) {
309       throw folly::AsyncSocketException(
310           folly::AsyncSocketException::INTERNAL_ERROR,
311           "failed to create test server socket",
312           errno);
313     }
314     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
315       throw folly::AsyncSocketException(
316           folly::AsyncSocketException::INTERNAL_ERROR,
317           "failed to put test server socket in "
318           "non-blocking mode",
319           errno);
320     }
321     if (enableTFO) {
322 #if FOLLY_ALLOW_TFO
323       folly::detail::tfo_enable(fd_, 100);
324 #endif
325     }
326
327     struct addrinfo hints, *res;
328     memset(&hints, 0, sizeof(hints));
329     hints.ai_family = AF_INET;
330     hints.ai_socktype = SOCK_STREAM;
331     hints.ai_flags = AI_PASSIVE;
332
333     if (getaddrinfo(nullptr, "0", &hints, &res)) {
334       throw folly::AsyncSocketException(
335           folly::AsyncSocketException::INTERNAL_ERROR,
336           "Attempted to bind address to socket with "
337           "bad getaddrinfo",
338           errno);
339     }
340
341     SCOPE_EXIT {
342       freeaddrinfo(res);
343     };
344
345     if (bufSize > 0) {
346       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
347       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
348     }
349
350     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
351       throw folly::AsyncSocketException(
352           folly::AsyncSocketException::INTERNAL_ERROR,
353           "failed to bind to async server socket for port 10",
354           errno);
355     }
356
357     if (listen(fd_, 10) != 0) {
358       throw folly::AsyncSocketException(
359           folly::AsyncSocketException::INTERNAL_ERROR,
360           "failed to listen on test server socket",
361           errno);
362     }
363
364     address_.setFromLocalAddress(fd_);
365     // The local address will contain 0.0.0.0.
366     // Change it to 127.0.0.1, so it can be used to connect to the server
367     address_.setFromIpPort("127.0.0.1", address_.getPort());
368   }
369
370   ~TestServer() {
371     if (fd_ != -1) {
372       close(fd_);
373     }
374   }
375
376   // Get the address for connecting to the server
377   const folly::SocketAddress& getAddress() const {
378     return address_;
379   }
380
381   int acceptFD(int timeout=50) {
382     namespace fsp = folly::portability::sockets;
383     struct pollfd pfd;
384     pfd.fd = fd_;
385     pfd.events = POLLIN;
386     int ret = poll(&pfd, 1, timeout);
387     if (ret == 0) {
388       throw folly::AsyncSocketException(
389           folly::AsyncSocketException::INTERNAL_ERROR,
390           "test server accept() timed out");
391     } else if (ret < 0) {
392       throw folly::AsyncSocketException(
393           folly::AsyncSocketException::INTERNAL_ERROR,
394           "test server accept() poll failed",
395           errno);
396     }
397
398     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
399     if (acceptedFd < 0) {
400       throw folly::AsyncSocketException(
401           folly::AsyncSocketException::INTERNAL_ERROR,
402           "test server accept() failed",
403           errno);
404     }
405
406     return acceptedFd;
407   }
408
409   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
410     int fd = acceptFD(timeout);
411     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
412   }
413
414   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
415                                                   int timeout = 50) {
416     int fd = acceptFD(timeout);
417     return folly::AsyncSocket::newSocket(evb, fd);
418   }
419
420   /**
421    * Accept a connection, read data from it, and verify that it matches the
422    * data in the specified buffer.
423    */
424   void verifyConnection(const char* buf, size_t len) {
425     // accept a connection
426     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
427     // read the data and compare it to the specified buffer
428     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
429     acceptedSocket->readAll(readbuf.get(), len);
430     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
431     // make sure we get EOF next
432     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
433     CHECK_EQ(bytesRead, 0);
434   }
435
436  private:
437   int fd_;
438   folly::SocketAddress address_;
439 };