722be94f35015762e8a1399eb42cba9842e68a46
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20
21 #include <boost/scoped_array.hpp>
22 #include <poll.h>
23
24 // This is a test-only header
25 /* using override */
26 using namespace folly;
27
28 enum StateEnum {
29   STATE_WAITING,
30   STATE_SUCCEEDED,
31   STATE_FAILED
32 };
33
34 typedef std::function<void()> VoidCallback;
35
36 class ConnCallback : public AsyncSocket::ConnectCallback {
37  public:
38   ConnCallback()
39     : state(STATE_WAITING)
40     , exception(AsyncSocketException::UNKNOWN, "none") {}
41
42   void connectSuccess() noexcept override {
43     state = STATE_SUCCEEDED;
44     if (successCallback) {
45       successCallback();
46     }
47   }
48
49   void connectErr(const AsyncSocketException& ex) noexcept override {
50     state = STATE_FAILED;
51     exception = ex;
52     if (errorCallback) {
53       errorCallback();
54     }
55   }
56
57   StateEnum state;
58   AsyncSocketException exception;
59   VoidCallback successCallback;
60   VoidCallback errorCallback;
61 };
62
63 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
64  public:
65   WriteCallback()
66     : state(STATE_WAITING)
67     , bytesWritten(0)
68     , exception(AsyncSocketException::UNKNOWN, "none") {}
69
70   void writeSuccess() noexcept override {
71     state = STATE_SUCCEEDED;
72     if (successCallback) {
73       successCallback();
74     }
75   }
76
77   void writeErr(size_t bytesWritten,
78                 const AsyncSocketException& ex) noexcept override {
79     state = STATE_FAILED;
80     this->bytesWritten = bytesWritten;
81     exception = ex;
82     if (errorCallback) {
83       errorCallback();
84     }
85   }
86
87   StateEnum state;
88   size_t bytesWritten;
89   AsyncSocketException exception;
90   VoidCallback successCallback;
91   VoidCallback errorCallback;
92 };
93
94 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
95  public:
96   explicit ReadCallback(size_t _maxBufferSz = 4096)
97     : state(STATE_WAITING)
98     , exception(AsyncSocketException::UNKNOWN, "none")
99     , buffers()
100     , maxBufferSz(_maxBufferSz) {}
101
102   ~ReadCallback() {
103     for (std::vector<Buffer>::iterator it = buffers.begin();
104          it != buffers.end();
105          ++it) {
106       it->free();
107     }
108     currentBuffer.free();
109   }
110
111   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
112     if (!currentBuffer.buffer) {
113       currentBuffer.allocate(maxBufferSz);
114     }
115     *bufReturn = currentBuffer.buffer;
116     *lenReturn = currentBuffer.length;
117   }
118
119   void readDataAvailable(size_t len) noexcept override {
120     currentBuffer.length = len;
121     buffers.push_back(currentBuffer);
122     currentBuffer.reset();
123     if (dataAvailableCallback) {
124       dataAvailableCallback();
125     }
126   }
127
128   void readEOF() noexcept override {
129     state = STATE_SUCCEEDED;
130   }
131
132   void readErr(const AsyncSocketException& ex) noexcept override {
133     state = STATE_FAILED;
134     exception = ex;
135   }
136
137   void verifyData(const char* expected, size_t expectedLen) const {
138     size_t offset = 0;
139     for (size_t idx = 0; idx < buffers.size(); ++idx) {
140       const auto& buf = buffers[idx];
141       size_t cmpLen = std::min(buf.length, expectedLen - offset);
142       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
143       CHECK_EQ(cmpLen, buf.length);
144       offset += cmpLen;
145     }
146     CHECK_EQ(offset, expectedLen);
147   }
148
149   size_t dataRead() const {
150     size_t ret = 0;
151     for (const auto& buf : buffers) {
152       ret += buf.length;
153     }
154     return ret;
155   }
156
157   class Buffer {
158    public:
159     Buffer() : buffer(nullptr), length(0) {}
160     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
161
162     void reset() {
163       buffer = nullptr;
164       length = 0;
165     }
166     void allocate(size_t length) {
167       assert(buffer == nullptr);
168       this->buffer = static_cast<char*>(malloc(length));
169       this->length = length;
170     }
171     void free() {
172       ::free(buffer);
173       reset();
174     }
175
176     char* buffer;
177     size_t length;
178   };
179
180   StateEnum state;
181   AsyncSocketException exception;
182   std::vector<Buffer> buffers;
183   Buffer currentBuffer;
184   VoidCallback dataAvailableCallback;
185   const size_t maxBufferSz;
186 };
187
188 class BufferCallback : public AsyncTransport::BufferCallback {
189  public:
190   BufferCallback() : buffered_(false), bufferCleared_(false) {}
191
192   void onEgressBuffered() override { buffered_ = true; }
193
194   void onEgressBufferCleared() override { bufferCleared_ = true; }
195
196   bool hasBuffered() const { return buffered_; }
197
198   bool hasBufferCleared() const { return bufferCleared_; }
199
200  private:
201   bool buffered_{false};
202   bool bufferCleared_{false};
203 };
204
205 class ReadVerifier {
206 };
207
208 class TestServer {
209  public:
210   // Create a TestServer.
211   // This immediately starts listening on an ephemeral port.
212   TestServer()
213     : fd_(-1) {
214     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
215     if (fd_ < 0) {
216       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
217                                 "failed to create test server socket", errno);
218     }
219     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
220       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
221                                 "failed to put test server socket in "
222                                 "non-blocking mode", errno);
223     }
224     if (listen(fd_, 10) != 0) {
225       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
226                                 "failed to listen on test server socket",
227                                 errno);
228     }
229
230     address_.setFromLocalAddress(fd_);
231     // The local address will contain 0.0.0.0.
232     // Change it to 127.0.0.1, so it can be used to connect to the server
233     address_.setFromIpPort("127.0.0.1", address_.getPort());
234   }
235
236   // Get the address for connecting to the server
237   const folly::SocketAddress& getAddress() const {
238     return address_;
239   }
240
241   int acceptFD(int timeout=50) {
242     struct pollfd pfd;
243     pfd.fd = fd_;
244     pfd.events = POLLIN;
245     int ret = poll(&pfd, 1, timeout);
246     if (ret == 0) {
247       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
248                                 "test server accept() timed out");
249     } else if (ret < 0) {
250       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
251                                 "test server accept() poll failed", errno);
252     }
253
254     int acceptedFd = ::accept(fd_, nullptr, nullptr);
255     if (acceptedFd < 0) {
256       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
257                                 "test server accept() failed", errno);
258     }
259
260     return acceptedFd;
261   }
262
263   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
264     int fd = acceptFD(timeout);
265     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
266   }
267
268   std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
269     int fd = acceptFD(timeout);
270     return AsyncSocket::newSocket(evb, fd);
271   }
272
273   /**
274    * Accept a connection, read data from it, and verify that it matches the
275    * data in the specified buffer.
276    */
277   void verifyConnection(const char* buf, size_t len) {
278     // accept a connection
279     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
280     // read the data and compare it to the specified buffer
281     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
282     acceptedSocket->readAll(readbuf.get(), len);
283     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
284     // make sure we get EOF next
285     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
286     CHECK_EQ(bytesRead, 0);
287   }
288
289  private:
290   int fd_;
291   folly::SocketAddress address_;
292 };