5d52ad204f68faa2859d91ec4c4c8bb42b58126d
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20
21 #include <boost/scoped_array.hpp>
22 #include <poll.h>
23
24 // This is a test-only header
25 /* using override */
26 using namespace folly;
27
28 enum StateEnum {
29   STATE_WAITING,
30   STATE_SUCCEEDED,
31   STATE_FAILED
32 };
33
34 typedef std::function<void()> VoidCallback;
35
36 class ConnCallback : public AsyncSocket::ConnectCallback {
37  public:
38   ConnCallback()
39     : state(STATE_WAITING)
40     , exception(AsyncSocketException::UNKNOWN, "none") {}
41
42   void connectSuccess() noexcept override {
43     state = STATE_SUCCEEDED;
44     if (successCallback) {
45       successCallback();
46     }
47   }
48
49   void connectErr(const AsyncSocketException& ex) noexcept override {
50     state = STATE_FAILED;
51     exception = ex;
52     if (errorCallback) {
53       errorCallback();
54     }
55   }
56
57   StateEnum state;
58   AsyncSocketException exception;
59   VoidCallback successCallback;
60   VoidCallback errorCallback;
61 };
62
63 class BufferCallback : public AsyncTransportWrapper::BufferCallback {
64  public:
65   BufferCallback()
66     : buffered_(false) {}
67
68   void onEgressBuffered() override {
69     buffered_ = true;
70   }
71
72   bool hasBuffered() const {
73     return buffered_;
74   }
75
76  private:
77   bool buffered_{false};
78 };
79
80 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
81  public:
82   WriteCallback()
83     : state(STATE_WAITING)
84     , bytesWritten(0)
85     , exception(AsyncSocketException::UNKNOWN, "none") {}
86
87   void writeSuccess() noexcept override {
88     state = STATE_SUCCEEDED;
89     if (successCallback) {
90       successCallback();
91     }
92   }
93
94   void writeErr(size_t bytesWritten,
95                 const AsyncSocketException& ex) noexcept override {
96     state = STATE_FAILED;
97     this->bytesWritten = bytesWritten;
98     exception = ex;
99     if (errorCallback) {
100       errorCallback();
101     }
102   }
103
104   StateEnum state;
105   size_t bytesWritten;
106   AsyncSocketException exception;
107   VoidCallback successCallback;
108   VoidCallback errorCallback;
109 };
110
111 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
112  public:
113   explicit ReadCallback(size_t _maxBufferSz = 4096)
114     : state(STATE_WAITING)
115     , exception(AsyncSocketException::UNKNOWN, "none")
116     , buffers()
117     , maxBufferSz(_maxBufferSz) {}
118
119   ~ReadCallback() {
120     for (std::vector<Buffer>::iterator it = buffers.begin();
121          it != buffers.end();
122          ++it) {
123       it->free();
124     }
125     currentBuffer.free();
126   }
127
128   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
129     if (!currentBuffer.buffer) {
130       currentBuffer.allocate(maxBufferSz);
131     }
132     *bufReturn = currentBuffer.buffer;
133     *lenReturn = currentBuffer.length;
134   }
135
136   void readDataAvailable(size_t len) noexcept override {
137     currentBuffer.length = len;
138     buffers.push_back(currentBuffer);
139     currentBuffer.reset();
140     if (dataAvailableCallback) {
141       dataAvailableCallback();
142     }
143   }
144
145   void readEOF() noexcept override {
146     state = STATE_SUCCEEDED;
147   }
148
149   void readErr(const AsyncSocketException& ex) noexcept override {
150     state = STATE_FAILED;
151     exception = ex;
152   }
153
154   void verifyData(const char* expected, size_t expectedLen) const {
155     size_t offset = 0;
156     for (size_t idx = 0; idx < buffers.size(); ++idx) {
157       const auto& buf = buffers[idx];
158       size_t cmpLen = std::min(buf.length, expectedLen - offset);
159       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
160       CHECK_EQ(cmpLen, buf.length);
161       offset += cmpLen;
162     }
163     CHECK_EQ(offset, expectedLen);
164   }
165
166   size_t dataRead() const {
167     size_t ret = 0;
168     for (const auto& buf : buffers) {
169       ret += buf.length;
170     }
171     return ret;
172   }
173
174   class Buffer {
175    public:
176     Buffer() : buffer(nullptr), length(0) {}
177     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
178
179     void reset() {
180       buffer = nullptr;
181       length = 0;
182     }
183     void allocate(size_t length) {
184       assert(buffer == nullptr);
185       this->buffer = static_cast<char*>(malloc(length));
186       this->length = length;
187     }
188     void free() {
189       ::free(buffer);
190       reset();
191     }
192
193     char* buffer;
194     size_t length;
195   };
196
197   StateEnum state;
198   AsyncSocketException exception;
199   std::vector<Buffer> buffers;
200   Buffer currentBuffer;
201   VoidCallback dataAvailableCallback;
202   const size_t maxBufferSz;
203 };
204
205 class ReadVerifier {
206 };
207
208 class TestServer {
209  public:
210   // Create a TestServer.
211   // This immediately starts listening on an ephemeral port.
212   TestServer()
213     : fd_(-1) {
214     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
215     if (fd_ < 0) {
216       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
217                                 "failed to create test server socket", errno);
218     }
219     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
220       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
221                                 "failed to put test server socket in "
222                                 "non-blocking mode", errno);
223     }
224     if (listen(fd_, 10) != 0) {
225       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
226                                 "failed to listen on test server socket",
227                                 errno);
228     }
229
230     address_.setFromLocalAddress(fd_);
231     // The local address will contain 0.0.0.0.
232     // Change it to 127.0.0.1, so it can be used to connect to the server
233     address_.setFromIpPort("127.0.0.1", address_.getPort());
234   }
235
236   // Get the address for connecting to the server
237   const folly::SocketAddress& getAddress() const {
238     return address_;
239   }
240
241   int acceptFD(int timeout=50) {
242     struct pollfd pfd;
243     pfd.fd = fd_;
244     pfd.events = POLLIN;
245     int ret = poll(&pfd, 1, timeout);
246     if (ret == 0) {
247       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
248                                 "test server accept() timed out");
249     } else if (ret < 0) {
250       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
251                                 "test server accept() poll failed", errno);
252     }
253
254     int acceptedFd = ::accept(fd_, nullptr, nullptr);
255     if (acceptedFd < 0) {
256       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
257                                 "test server accept() failed", errno);
258     }
259
260     return acceptedFd;
261   }
262
263   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
264     int fd = acceptFD(timeout);
265     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
266   }
267
268   std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
269     int fd = acceptFD(timeout);
270     return AsyncSocket::newSocket(evb, fd);
271   }
272
273   /**
274    * Accept a connection, read data from it, and verify that it matches the
275    * data in the specified buffer.
276    */
277   void verifyConnection(const char* buf, size_t len) {
278     // accept a connection
279     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
280     // read the data and compare it to the specified buffer
281     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
282     acceptedSocket->readAll(readbuf.get(), len);
283     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
284     // make sure we get EOF next
285     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
286     CHECK_EQ(bytesRead, 0);
287   }
288
289  private:
290   int fd_;
291   folly::SocketAddress address_;
292 };