Fix copyright lines
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2015-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23 #include <memory>
24
25 enum StateEnum {
26   STATE_WAITING,
27   STATE_SUCCEEDED,
28   STATE_FAILED
29 };
30
31 typedef std::function<void()> VoidCallback;
32
33 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
34  public:
35   ConnCallback()
36       : state(STATE_WAITING),
37         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
38
39   void connectSuccess() noexcept override {
40     state = STATE_SUCCEEDED;
41     if (successCallback) {
42       successCallback();
43     }
44   }
45
46   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
47     state = STATE_FAILED;
48     exception = ex;
49     if (errorCallback) {
50       errorCallback();
51     }
52   }
53
54   StateEnum state;
55   folly::AsyncSocketException exception;
56   VoidCallback successCallback;
57   VoidCallback errorCallback;
58 };
59
60 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
61  public:
62   WriteCallback()
63       : state(STATE_WAITING),
64         bytesWritten(0),
65         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
66
67   void writeSuccess() noexcept override {
68     state = STATE_SUCCEEDED;
69     if (successCallback) {
70       successCallback();
71     }
72   }
73
74   void writeErr(size_t nBytesWritten,
75                 const folly::AsyncSocketException& ex) noexcept override {
76     LOG(ERROR) << ex.what();
77     state = STATE_FAILED;
78     this->bytesWritten = nBytesWritten;
79     exception = ex;
80     if (errorCallback) {
81       errorCallback();
82     }
83   }
84
85   StateEnum state;
86   std::atomic<size_t> bytesWritten;
87   folly::AsyncSocketException exception;
88   VoidCallback successCallback;
89   VoidCallback errorCallback;
90 };
91
92 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
93  public:
94   explicit ReadCallback(size_t _maxBufferSz = 4096)
95       : state(STATE_WAITING),
96         exception(folly::AsyncSocketException::UNKNOWN, "none"),
97         buffers(),
98         maxBufferSz(_maxBufferSz) {}
99
100   ~ReadCallback() override {
101     for (std::vector<Buffer>::iterator it = buffers.begin();
102          it != buffers.end();
103          ++it) {
104       it->free();
105     }
106     currentBuffer.free();
107   }
108
109   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
110     if (!currentBuffer.buffer) {
111       currentBuffer.allocate(maxBufferSz);
112     }
113     *bufReturn = currentBuffer.buffer;
114     *lenReturn = currentBuffer.length;
115   }
116
117   void readDataAvailable(size_t len) noexcept override {
118     currentBuffer.length = len;
119     buffers.push_back(currentBuffer);
120     currentBuffer.reset();
121     if (dataAvailableCallback) {
122       dataAvailableCallback();
123     }
124   }
125
126   void readEOF() noexcept override {
127     state = STATE_SUCCEEDED;
128   }
129
130   void readErr(const folly::AsyncSocketException& ex) noexcept override {
131     state = STATE_FAILED;
132     exception = ex;
133   }
134
135   void verifyData(const char* expected, size_t expectedLen) const {
136     size_t offset = 0;
137     for (size_t idx = 0; idx < buffers.size(); ++idx) {
138       const auto& buf = buffers[idx];
139       size_t cmpLen = std::min(buf.length, expectedLen - offset);
140       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
141       CHECK_EQ(cmpLen, buf.length);
142       offset += cmpLen;
143     }
144     CHECK_EQ(offset, expectedLen);
145   }
146
147   size_t dataRead() const {
148     size_t ret = 0;
149     for (const auto& buf : buffers) {
150       ret += buf.length;
151     }
152     return ret;
153   }
154
155   class Buffer {
156    public:
157     Buffer() : buffer(nullptr), length(0) {}
158     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
159
160     void reset() {
161       buffer = nullptr;
162       length = 0;
163     }
164     void allocate(size_t len) {
165       assert(buffer == nullptr);
166       this->buffer = static_cast<char*>(malloc(len));
167       this->length = len;
168     }
169     void free() {
170       ::free(buffer);
171       reset();
172     }
173
174     char* buffer;
175     size_t length;
176   };
177
178   StateEnum state;
179   folly::AsyncSocketException exception;
180   std::vector<Buffer> buffers;
181   Buffer currentBuffer;
182   VoidCallback dataAvailableCallback;
183   const size_t maxBufferSz;
184 };
185
186 class BufferCallback : public folly::AsyncTransport::BufferCallback {
187  public:
188   BufferCallback() : buffered_(false), bufferCleared_(false) {}
189
190   void onEgressBuffered() override { buffered_ = true; }
191
192   void onEgressBufferCleared() override { bufferCleared_ = true; }
193
194   bool hasBuffered() const { return buffered_; }
195
196   bool hasBufferCleared() const { return bufferCleared_; }
197
198  private:
199   bool buffered_{false};
200   bool bufferCleared_{false};
201 };
202
203 class ReadVerifier {
204 };
205
206 class TestSendMsgParamsCallback :
207     public folly::AsyncSocket::SendMsgParamsCallback {
208  public:
209   TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
210   : flags_(flags),
211     writeFlags_(folly::WriteFlags::NONE),
212     dataSize_(dataSize),
213     data_(data),
214     queriedFlags_(false),
215     queriedData_(false)
216   {}
217
218   void reset(int flags) {
219     flags_ = flags;
220     writeFlags_ = folly::WriteFlags::NONE;
221     queriedFlags_ = false;
222     queriedData_ = false;
223   }
224
225   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
226                                                                   override {
227     queriedFlags_ = true;
228     if (writeFlags_ == folly::WriteFlags::NONE) {
229       writeFlags_ = flags;
230     } else {
231       assert(flags == writeFlags_);
232     }
233     return flags_;
234   }
235
236   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
237     queriedData_ = true;
238     if (writeFlags_ == folly::WriteFlags::NONE) {
239       writeFlags_ = flags;
240     } else {
241       assert(flags == writeFlags_);
242     }
243     assert(data != nullptr);
244     memcpy(data, data_, dataSize_);
245   }
246
247   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
248     if (writeFlags_ == folly::WriteFlags::NONE) {
249       writeFlags_ = flags;
250     } else {
251       assert(flags == writeFlags_);
252     }
253     return dataSize_;
254   }
255
256   int flags_;
257   folly::WriteFlags writeFlags_;
258   uint32_t dataSize_;
259   void* data_;
260   bool queriedFlags_;
261   bool queriedData_;
262 };
263
264 class TestServer {
265  public:
266   // Create a TestServer.
267   // This immediately starts listening on an ephemeral port.
268   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
269     namespace fsp = folly::portability::sockets;
270     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
271     if (fd_ < 0) {
272       throw folly::AsyncSocketException(
273           folly::AsyncSocketException::INTERNAL_ERROR,
274           "failed to create test server socket",
275           errno);
276     }
277     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
278       throw folly::AsyncSocketException(
279           folly::AsyncSocketException::INTERNAL_ERROR,
280           "failed to put test server socket in "
281           "non-blocking mode",
282           errno);
283     }
284     if (enableTFO) {
285 #if FOLLY_ALLOW_TFO
286       folly::detail::tfo_enable(fd_, 100);
287 #endif
288     }
289
290     struct addrinfo hints, *res;
291     memset(&hints, 0, sizeof(hints));
292     hints.ai_family = AF_INET;
293     hints.ai_socktype = SOCK_STREAM;
294     hints.ai_flags = AI_PASSIVE;
295
296     if (getaddrinfo(nullptr, "0", &hints, &res)) {
297       throw folly::AsyncSocketException(
298           folly::AsyncSocketException::INTERNAL_ERROR,
299           "Attempted to bind address to socket with "
300           "bad getaddrinfo",
301           errno);
302     }
303
304     SCOPE_EXIT {
305       freeaddrinfo(res);
306     };
307
308     if (bufSize > 0) {
309       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
310       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
311     }
312
313     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
314       throw folly::AsyncSocketException(
315           folly::AsyncSocketException::INTERNAL_ERROR,
316           "failed to bind to async server socket for port 10",
317           errno);
318     }
319
320     if (listen(fd_, 10) != 0) {
321       throw folly::AsyncSocketException(
322           folly::AsyncSocketException::INTERNAL_ERROR,
323           "failed to listen on test server socket",
324           errno);
325     }
326
327     address_.setFromLocalAddress(fd_);
328     // The local address will contain 0.0.0.0.
329     // Change it to 127.0.0.1, so it can be used to connect to the server
330     address_.setFromIpPort("127.0.0.1", address_.getPort());
331   }
332
333   ~TestServer() {
334     if (fd_ != -1) {
335       close(fd_);
336     }
337   }
338
339   // Get the address for connecting to the server
340   const folly::SocketAddress& getAddress() const {
341     return address_;
342   }
343
344   int acceptFD(int timeout=50) {
345     namespace fsp = folly::portability::sockets;
346     struct pollfd pfd;
347     pfd.fd = fd_;
348     pfd.events = POLLIN;
349     int ret = poll(&pfd, 1, timeout);
350     if (ret == 0) {
351       throw folly::AsyncSocketException(
352           folly::AsyncSocketException::INTERNAL_ERROR,
353           "test server accept() timed out");
354     } else if (ret < 0) {
355       throw folly::AsyncSocketException(
356           folly::AsyncSocketException::INTERNAL_ERROR,
357           "test server accept() poll failed",
358           errno);
359     }
360
361     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
362     if (acceptedFd < 0) {
363       throw folly::AsyncSocketException(
364           folly::AsyncSocketException::INTERNAL_ERROR,
365           "test server accept() failed",
366           errno);
367     }
368
369     return acceptedFd;
370   }
371
372   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
373     int fd = acceptFD(timeout);
374     return std::make_shared<BlockingSocket>(fd);
375   }
376
377   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
378                                                   int timeout = 50) {
379     int fd = acceptFD(timeout);
380     return folly::AsyncSocket::newSocket(evb, fd);
381   }
382
383   /**
384    * Accept a connection, read data from it, and verify that it matches the
385    * data in the specified buffer.
386    */
387   void verifyConnection(const char* buf, size_t len) {
388     // accept a connection
389     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
390     // read the data and compare it to the specified buffer
391     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
392     acceptedSocket->readAll(readbuf.get(), len);
393     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
394     // make sure we get EOF next
395     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
396     CHECK_EQ(bytesRead, 0);
397   }
398
399  private:
400   int fd_;
401   folly::SocketAddress address_;
402 };