Fix the build of the socket tests on Windows
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t nBytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     LOG(ERROR) << ex.what();
76     state = STATE_FAILED;
77     this->bytesWritten = nBytesWritten;
78     exception = ex;
79     if (errorCallback) {
80       errorCallback();
81     }
82   }
83
84   StateEnum state;
85   size_t bytesWritten;
86   folly::AsyncSocketException exception;
87   VoidCallback successCallback;
88   VoidCallback errorCallback;
89 };
90
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
92  public:
93   explicit ReadCallback(size_t _maxBufferSz = 4096)
94       : state(STATE_WAITING),
95         exception(folly::AsyncSocketException::UNKNOWN, "none"),
96         buffers(),
97         maxBufferSz(_maxBufferSz) {}
98
99   ~ReadCallback() override {
100     for (std::vector<Buffer>::iterator it = buffers.begin();
101          it != buffers.end();
102          ++it) {
103       it->free();
104     }
105     currentBuffer.free();
106   }
107
108   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109     if (!currentBuffer.buffer) {
110       currentBuffer.allocate(maxBufferSz);
111     }
112     *bufReturn = currentBuffer.buffer;
113     *lenReturn = currentBuffer.length;
114   }
115
116   void readDataAvailable(size_t len) noexcept override {
117     currentBuffer.length = len;
118     buffers.push_back(currentBuffer);
119     currentBuffer.reset();
120     if (dataAvailableCallback) {
121       dataAvailableCallback();
122     }
123   }
124
125   void readEOF() noexcept override {
126     state = STATE_SUCCEEDED;
127   }
128
129   void readErr(const folly::AsyncSocketException& ex) noexcept override {
130     state = STATE_FAILED;
131     exception = ex;
132   }
133
134   void verifyData(const char* expected, size_t expectedLen) const {
135     size_t offset = 0;
136     for (size_t idx = 0; idx < buffers.size(); ++idx) {
137       const auto& buf = buffers[idx];
138       size_t cmpLen = std::min(buf.length, expectedLen - offset);
139       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140       CHECK_EQ(cmpLen, buf.length);
141       offset += cmpLen;
142     }
143     CHECK_EQ(offset, expectedLen);
144   }
145
146   size_t dataRead() const {
147     size_t ret = 0;
148     for (const auto& buf : buffers) {
149       ret += buf.length;
150     }
151     return ret;
152   }
153
154   class Buffer {
155    public:
156     Buffer() : buffer(nullptr), length(0) {}
157     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
158
159     void reset() {
160       buffer = nullptr;
161       length = 0;
162     }
163     void allocate(size_t len) {
164       assert(buffer == nullptr);
165       this->buffer = static_cast<char*>(malloc(len));
166       this->length = len;
167     }
168     void free() {
169       ::free(buffer);
170       reset();
171     }
172
173     char* buffer;
174     size_t length;
175   };
176
177   StateEnum state;
178   folly::AsyncSocketException exception;
179   std::vector<Buffer> buffers;
180   Buffer currentBuffer;
181   VoidCallback dataAvailableCallback;
182   const size_t maxBufferSz;
183 };
184
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
186  public:
187   BufferCallback() : buffered_(false), bufferCleared_(false) {}
188
189   void onEgressBuffered() override { buffered_ = true; }
190
191   void onEgressBufferCleared() override { bufferCleared_ = true; }
192
193   bool hasBuffered() const { return buffered_; }
194
195   bool hasBufferCleared() const { return bufferCleared_; }
196
197  private:
198   bool buffered_{false};
199   bool bufferCleared_{false};
200 };
201
202 class ReadVerifier {
203 };
204
205 class TestSendMsgParamsCallback :
206     public folly::AsyncSocket::SendMsgParamsCallback {
207  public:
208   TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
209   : flags_(flags),
210     writeFlags_(folly::WriteFlags::NONE),
211     dataSize_(dataSize),
212     data_(data),
213     queriedFlags_(false),
214     queriedData_(false)
215   {}
216
217   void reset(int flags) {
218     flags_ = flags;
219     writeFlags_ = folly::WriteFlags::NONE;
220     queriedFlags_ = false;
221     queriedData_ = false;
222   }
223
224   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
225                                                                   override {
226     queriedFlags_ = true;
227     if (writeFlags_ == folly::WriteFlags::NONE) {
228       writeFlags_ = flags;
229     } else {
230       assert(flags == writeFlags_);
231     }
232     return flags_;
233   }
234
235   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
236     queriedData_ = true;
237     if (writeFlags_ == folly::WriteFlags::NONE) {
238       writeFlags_ = flags;
239     } else {
240       assert(flags == writeFlags_);
241     }
242     assert(data != nullptr);
243     memcpy(data, data_, dataSize_);
244   }
245
246   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
247     if (writeFlags_ == folly::WriteFlags::NONE) {
248       writeFlags_ = flags;
249     } else {
250       assert(flags == writeFlags_);
251     }
252     return dataSize_;
253   }
254
255   int flags_;
256   folly::WriteFlags writeFlags_;
257   uint32_t dataSize_;
258   void* data_;
259   bool queriedFlags_;
260   bool queriedData_;
261 };
262
263 class TestServer {
264  public:
265   // Create a TestServer.
266   // This immediately starts listening on an ephemeral port.
267   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
268     namespace fsp = folly::portability::sockets;
269     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
270     if (fd_ < 0) {
271       throw folly::AsyncSocketException(
272           folly::AsyncSocketException::INTERNAL_ERROR,
273           "failed to create test server socket",
274           errno);
275     }
276     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
277       throw folly::AsyncSocketException(
278           folly::AsyncSocketException::INTERNAL_ERROR,
279           "failed to put test server socket in "
280           "non-blocking mode",
281           errno);
282     }
283     if (enableTFO) {
284 #if FOLLY_ALLOW_TFO
285       folly::detail::tfo_enable(fd_, 100);
286 #endif
287     }
288
289     struct addrinfo hints, *res;
290     memset(&hints, 0, sizeof(hints));
291     hints.ai_family = AF_INET;
292     hints.ai_socktype = SOCK_STREAM;
293     hints.ai_flags = AI_PASSIVE;
294
295     if (getaddrinfo(nullptr, "0", &hints, &res)) {
296       throw folly::AsyncSocketException(
297           folly::AsyncSocketException::INTERNAL_ERROR,
298           "Attempted to bind address to socket with "
299           "bad getaddrinfo",
300           errno);
301     }
302
303     SCOPE_EXIT {
304       freeaddrinfo(res);
305     };
306
307     if (bufSize > 0) {
308       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
309       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
310     }
311
312     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
313       throw folly::AsyncSocketException(
314           folly::AsyncSocketException::INTERNAL_ERROR,
315           "failed to bind to async server socket for port 10",
316           errno);
317     }
318
319     if (listen(fd_, 10) != 0) {
320       throw folly::AsyncSocketException(
321           folly::AsyncSocketException::INTERNAL_ERROR,
322           "failed to listen on test server socket",
323           errno);
324     }
325
326     address_.setFromLocalAddress(fd_);
327     // The local address will contain 0.0.0.0.
328     // Change it to 127.0.0.1, so it can be used to connect to the server
329     address_.setFromIpPort("127.0.0.1", address_.getPort());
330   }
331
332   ~TestServer() {
333     if (fd_ != -1) {
334       close(fd_);
335     }
336   }
337
338   // Get the address for connecting to the server
339   const folly::SocketAddress& getAddress() const {
340     return address_;
341   }
342
343   int acceptFD(int timeout=50) {
344     namespace fsp = folly::portability::sockets;
345     struct pollfd pfd;
346     pfd.fd = fd_;
347     pfd.events = POLLIN;
348     int ret = poll(&pfd, 1, timeout);
349     if (ret == 0) {
350       throw folly::AsyncSocketException(
351           folly::AsyncSocketException::INTERNAL_ERROR,
352           "test server accept() timed out");
353     } else if (ret < 0) {
354       throw folly::AsyncSocketException(
355           folly::AsyncSocketException::INTERNAL_ERROR,
356           "test server accept() poll failed",
357           errno);
358     }
359
360     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
361     if (acceptedFd < 0) {
362       throw folly::AsyncSocketException(
363           folly::AsyncSocketException::INTERNAL_ERROR,
364           "test server accept() failed",
365           errno);
366     }
367
368     return acceptedFd;
369   }
370
371   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
372     int fd = acceptFD(timeout);
373     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
374   }
375
376   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
377                                                   int timeout = 50) {
378     int fd = acceptFD(timeout);
379     return folly::AsyncSocket::newSocket(evb, fd);
380   }
381
382   /**
383    * Accept a connection, read data from it, and verify that it matches the
384    * data in the specified buffer.
385    */
386   void verifyConnection(const char* buf, size_t len) {
387     // accept a connection
388     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
389     // read the data and compare it to the specified buffer
390     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
391     acceptedSocket->readAll(readbuf.get(), len);
392     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
393     // make sure we get EOF next
394     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
395     CHECK_EQ(bytesRead, 0);
396   }
397
398  private:
399   int fd_;
400   folly::SocketAddress address_;
401 };