fe03fa6f9412855580640f5e20d45e8fa1b44896
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20
21 #include <boost/scoped_array.hpp>
22 #include <poll.h>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t bytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     state = STATE_FAILED;
76     this->bytesWritten = bytesWritten;
77     exception = ex;
78     if (errorCallback) {
79       errorCallback();
80     }
81   }
82
83   StateEnum state;
84   size_t bytesWritten;
85   folly::AsyncSocketException exception;
86   VoidCallback successCallback;
87   VoidCallback errorCallback;
88 };
89
90 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
91  public:
92   explicit ReadCallback(size_t _maxBufferSz = 4096)
93       : state(STATE_WAITING),
94         exception(folly::AsyncSocketException::UNKNOWN, "none"),
95         buffers(),
96         maxBufferSz(_maxBufferSz) {}
97
98   ~ReadCallback() {
99     for (std::vector<Buffer>::iterator it = buffers.begin();
100          it != buffers.end();
101          ++it) {
102       it->free();
103     }
104     currentBuffer.free();
105   }
106
107   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
108     if (!currentBuffer.buffer) {
109       currentBuffer.allocate(maxBufferSz);
110     }
111     *bufReturn = currentBuffer.buffer;
112     *lenReturn = currentBuffer.length;
113   }
114
115   void readDataAvailable(size_t len) noexcept override {
116     currentBuffer.length = len;
117     buffers.push_back(currentBuffer);
118     currentBuffer.reset();
119     if (dataAvailableCallback) {
120       dataAvailableCallback();
121     }
122   }
123
124   void readEOF() noexcept override {
125     state = STATE_SUCCEEDED;
126   }
127
128   void readErr(const folly::AsyncSocketException& ex) noexcept override {
129     state = STATE_FAILED;
130     exception = ex;
131   }
132
133   void verifyData(const char* expected, size_t expectedLen) const {
134     size_t offset = 0;
135     for (size_t idx = 0; idx < buffers.size(); ++idx) {
136       const auto& buf = buffers[idx];
137       size_t cmpLen = std::min(buf.length, expectedLen - offset);
138       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
139       CHECK_EQ(cmpLen, buf.length);
140       offset += cmpLen;
141     }
142     CHECK_EQ(offset, expectedLen);
143   }
144
145   size_t dataRead() const {
146     size_t ret = 0;
147     for (const auto& buf : buffers) {
148       ret += buf.length;
149     }
150     return ret;
151   }
152
153   class Buffer {
154    public:
155     Buffer() : buffer(nullptr), length(0) {}
156     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
157
158     void reset() {
159       buffer = nullptr;
160       length = 0;
161     }
162     void allocate(size_t length) {
163       assert(buffer == nullptr);
164       this->buffer = static_cast<char*>(malloc(length));
165       this->length = length;
166     }
167     void free() {
168       ::free(buffer);
169       reset();
170     }
171
172     char* buffer;
173     size_t length;
174   };
175
176   StateEnum state;
177   folly::AsyncSocketException exception;
178   std::vector<Buffer> buffers;
179   Buffer currentBuffer;
180   VoidCallback dataAvailableCallback;
181   const size_t maxBufferSz;
182 };
183
184 class BufferCallback : public folly::AsyncTransport::BufferCallback {
185  public:
186   BufferCallback() : buffered_(false), bufferCleared_(false) {}
187
188   void onEgressBuffered() override { buffered_ = true; }
189
190   void onEgressBufferCleared() override { bufferCleared_ = true; }
191
192   bool hasBuffered() const { return buffered_; }
193
194   bool hasBufferCleared() const { return bufferCleared_; }
195
196  private:
197   bool buffered_{false};
198   bool bufferCleared_{false};
199 };
200
201 class ReadVerifier {
202 };
203
204 class TestServer {
205  public:
206   // Create a TestServer.
207   // This immediately starts listening on an ephemeral port.
208   TestServer()
209     : fd_(-1) {
210     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
211     if (fd_ < 0) {
212       throw folly::AsyncSocketException(
213           folly::AsyncSocketException::INTERNAL_ERROR,
214           "failed to create test server socket",
215           errno);
216     }
217     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
218       throw folly::AsyncSocketException(
219           folly::AsyncSocketException::INTERNAL_ERROR,
220           "failed to put test server socket in "
221           "non-blocking mode",
222           errno);
223     }
224     if (listen(fd_, 10) != 0) {
225       throw folly::AsyncSocketException(
226           folly::AsyncSocketException::INTERNAL_ERROR,
227           "failed to listen on test server socket",
228           errno);
229     }
230
231     address_.setFromLocalAddress(fd_);
232     // The local address will contain 0.0.0.0.
233     // Change it to 127.0.0.1, so it can be used to connect to the server
234     address_.setFromIpPort("127.0.0.1", address_.getPort());
235   }
236
237   // Get the address for connecting to the server
238   const folly::SocketAddress& getAddress() const {
239     return address_;
240   }
241
242   int acceptFD(int timeout=50) {
243     struct pollfd pfd;
244     pfd.fd = fd_;
245     pfd.events = POLLIN;
246     int ret = poll(&pfd, 1, timeout);
247     if (ret == 0) {
248       throw folly::AsyncSocketException(
249           folly::AsyncSocketException::INTERNAL_ERROR,
250           "test server accept() timed out");
251     } else if (ret < 0) {
252       throw folly::AsyncSocketException(
253           folly::AsyncSocketException::INTERNAL_ERROR,
254           "test server accept() poll failed",
255           errno);
256     }
257
258     int acceptedFd = ::accept(fd_, nullptr, nullptr);
259     if (acceptedFd < 0) {
260       throw folly::AsyncSocketException(
261           folly::AsyncSocketException::INTERNAL_ERROR,
262           "test server accept() failed",
263           errno);
264     }
265
266     return acceptedFd;
267   }
268
269   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
270     int fd = acceptFD(timeout);
271     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
272   }
273
274   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
275                                                   int timeout = 50) {
276     int fd = acceptFD(timeout);
277     return folly::AsyncSocket::newSocket(evb, fd);
278   }
279
280   /**
281    * Accept a connection, read data from it, and verify that it matches the
282    * data in the specified buffer.
283    */
284   void verifyConnection(const char* buf, size_t len) {
285     // accept a connection
286     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
287     // read the data and compare it to the specified buffer
288     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
289     acceptedSocket->readAll(readbuf.get(), len);
290     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
291     // make sure we get EOF next
292     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
293     CHECK_EQ(bytesRead, 0);
294   }
295
296  private:
297   int fd_;
298   folly::SocketAddress address_;
299 };