Timestamping callback interface in folly::AsyncSocket
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t nBytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     LOG(ERROR) << ex.what();
76     state = STATE_FAILED;
77     this->bytesWritten = nBytesWritten;
78     exception = ex;
79     if (errorCallback) {
80       errorCallback();
81     }
82   }
83
84   StateEnum state;
85   size_t bytesWritten;
86   folly::AsyncSocketException exception;
87   VoidCallback successCallback;
88   VoidCallback errorCallback;
89 };
90
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
92  public:
93   explicit ReadCallback(size_t _maxBufferSz = 4096)
94       : state(STATE_WAITING),
95         exception(folly::AsyncSocketException::UNKNOWN, "none"),
96         buffers(),
97         maxBufferSz(_maxBufferSz) {}
98
99   ~ReadCallback() {
100     for (std::vector<Buffer>::iterator it = buffers.begin();
101          it != buffers.end();
102          ++it) {
103       it->free();
104     }
105     currentBuffer.free();
106   }
107
108   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109     if (!currentBuffer.buffer) {
110       currentBuffer.allocate(maxBufferSz);
111     }
112     *bufReturn = currentBuffer.buffer;
113     *lenReturn = currentBuffer.length;
114   }
115
116   void readDataAvailable(size_t len) noexcept override {
117     currentBuffer.length = len;
118     buffers.push_back(currentBuffer);
119     currentBuffer.reset();
120     if (dataAvailableCallback) {
121       dataAvailableCallback();
122     }
123   }
124
125   void readEOF() noexcept override {
126     state = STATE_SUCCEEDED;
127   }
128
129   void readErr(const folly::AsyncSocketException& ex) noexcept override {
130     state = STATE_FAILED;
131     exception = ex;
132   }
133
134   void verifyData(const char* expected, size_t expectedLen) const {
135     size_t offset = 0;
136     for (size_t idx = 0; idx < buffers.size(); ++idx) {
137       const auto& buf = buffers[idx];
138       size_t cmpLen = std::min(buf.length, expectedLen - offset);
139       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140       CHECK_EQ(cmpLen, buf.length);
141       offset += cmpLen;
142     }
143     CHECK_EQ(offset, expectedLen);
144   }
145
146   size_t dataRead() const {
147     size_t ret = 0;
148     for (const auto& buf : buffers) {
149       ret += buf.length;
150     }
151     return ret;
152   }
153
154   class Buffer {
155    public:
156     Buffer() : buffer(nullptr), length(0) {}
157     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
158
159     void reset() {
160       buffer = nullptr;
161       length = 0;
162     }
163     void allocate(size_t len) {
164       assert(buffer == nullptr);
165       this->buffer = static_cast<char*>(malloc(len));
166       this->length = len;
167     }
168     void free() {
169       ::free(buffer);
170       reset();
171     }
172
173     char* buffer;
174     size_t length;
175   };
176
177   StateEnum state;
178   folly::AsyncSocketException exception;
179   std::vector<Buffer> buffers;
180   Buffer currentBuffer;
181   VoidCallback dataAvailableCallback;
182   const size_t maxBufferSz;
183 };
184
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
186  public:
187   BufferCallback() : buffered_(false), bufferCleared_(false) {}
188
189   void onEgressBuffered() override { buffered_ = true; }
190
191   void onEgressBufferCleared() override { bufferCleared_ = true; }
192
193   bool hasBuffered() const { return buffered_; }
194
195   bool hasBufferCleared() const { return bufferCleared_; }
196
197  private:
198   bool buffered_{false};
199   bool bufferCleared_{false};
200 };
201
202 class ReadVerifier {
203 };
204
205 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
206  public:
207   TestErrMessageCallback()
208     : exception_(folly::AsyncSocketException::UNKNOWN, "none")
209   {}
210
211   void errMessage(const cmsghdr& cmsg) noexcept override {
212     if (cmsg.cmsg_level == SOL_SOCKET &&
213       cmsg.cmsg_type == SCM_TIMESTAMPING) {
214       gotTimestamp_ = true;
215     } else if (
216       (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
217       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
218       gotByteSeq_ = true;
219     }
220   }
221
222   void errMessageError(
223       const folly::AsyncSocketException& ex) noexcept override {
224     exception_ = ex;
225   }
226
227   folly::AsyncSocketException exception_;
228   bool gotTimestamp_{false};
229   bool gotByteSeq_{false};
230 };
231
232 class TestServer {
233  public:
234   // Create a TestServer.
235   // This immediately starts listening on an ephemeral port.
236   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
237     namespace fsp = folly::portability::sockets;
238     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
239     if (fd_ < 0) {
240       throw folly::AsyncSocketException(
241           folly::AsyncSocketException::INTERNAL_ERROR,
242           "failed to create test server socket",
243           errno);
244     }
245     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
246       throw folly::AsyncSocketException(
247           folly::AsyncSocketException::INTERNAL_ERROR,
248           "failed to put test server socket in "
249           "non-blocking mode",
250           errno);
251     }
252     if (enableTFO) {
253 #if FOLLY_ALLOW_TFO
254       folly::detail::tfo_enable(fd_, 100);
255 #endif
256     }
257
258     struct addrinfo hints, *res;
259     memset(&hints, 0, sizeof(hints));
260     hints.ai_family = AF_INET;
261     hints.ai_socktype = SOCK_STREAM;
262     hints.ai_flags = AI_PASSIVE;
263
264     if (getaddrinfo(nullptr, "0", &hints, &res)) {
265       throw folly::AsyncSocketException(
266           folly::AsyncSocketException::INTERNAL_ERROR,
267           "Attempted to bind address to socket with "
268           "bad getaddrinfo",
269           errno);
270     }
271
272     SCOPE_EXIT {
273       freeaddrinfo(res);
274     };
275
276     if (bufSize > 0) {
277       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
278       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
279     }
280
281     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
282       throw folly::AsyncSocketException(
283           folly::AsyncSocketException::INTERNAL_ERROR,
284           "failed to bind to async server socket for port 10",
285           errno);
286     }
287
288     if (listen(fd_, 10) != 0) {
289       throw folly::AsyncSocketException(
290           folly::AsyncSocketException::INTERNAL_ERROR,
291           "failed to listen on test server socket",
292           errno);
293     }
294
295     address_.setFromLocalAddress(fd_);
296     // The local address will contain 0.0.0.0.
297     // Change it to 127.0.0.1, so it can be used to connect to the server
298     address_.setFromIpPort("127.0.0.1", address_.getPort());
299   }
300
301   ~TestServer() {
302     if (fd_ != -1) {
303       close(fd_);
304     }
305   }
306
307   // Get the address for connecting to the server
308   const folly::SocketAddress& getAddress() const {
309     return address_;
310   }
311
312   int acceptFD(int timeout=50) {
313     namespace fsp = folly::portability::sockets;
314     struct pollfd pfd;
315     pfd.fd = fd_;
316     pfd.events = POLLIN;
317     int ret = poll(&pfd, 1, timeout);
318     if (ret == 0) {
319       throw folly::AsyncSocketException(
320           folly::AsyncSocketException::INTERNAL_ERROR,
321           "test server accept() timed out");
322     } else if (ret < 0) {
323       throw folly::AsyncSocketException(
324           folly::AsyncSocketException::INTERNAL_ERROR,
325           "test server accept() poll failed",
326           errno);
327     }
328
329     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
330     if (acceptedFd < 0) {
331       throw folly::AsyncSocketException(
332           folly::AsyncSocketException::INTERNAL_ERROR,
333           "test server accept() failed",
334           errno);
335     }
336
337     return acceptedFd;
338   }
339
340   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
341     int fd = acceptFD(timeout);
342     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
343   }
344
345   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
346                                                   int timeout = 50) {
347     int fd = acceptFD(timeout);
348     return folly::AsyncSocket::newSocket(evb, fd);
349   }
350
351   /**
352    * Accept a connection, read data from it, and verify that it matches the
353    * data in the specified buffer.
354    */
355   void verifyConnection(const char* buf, size_t len) {
356     // accept a connection
357     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
358     // read the data and compare it to the specified buffer
359     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
360     acceptedSocket->readAll(readbuf.get(), len);
361     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
362     // make sure we get EOF next
363     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
364     CHECK_EQ(bytesRead, 0);
365   }
366
367  private:
368   int fd_;
369   folly::SocketAddress address_;
370 };