AsyncSocket::writeRequest() and its first user wangle::FileRegion
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20
21 #include <boost/scoped_array.hpp>
22 #include <poll.h>
23
24 // This is a test-only header
25 /* using override */
26 using namespace folly;
27
28 enum StateEnum {
29   STATE_WAITING,
30   STATE_SUCCEEDED,
31   STATE_FAILED
32 };
33
34 typedef std::function<void()> VoidCallback;
35
36 class ConnCallback : public AsyncSocket::ConnectCallback {
37  public:
38   ConnCallback()
39     : state(STATE_WAITING)
40     , exception(AsyncSocketException::UNKNOWN, "none") {}
41
42   void connectSuccess() noexcept override {
43     state = STATE_SUCCEEDED;
44     if (successCallback) {
45       successCallback();
46     }
47   }
48
49   void connectErr(const AsyncSocketException& ex) noexcept override {
50     state = STATE_FAILED;
51     exception = ex;
52     if (errorCallback) {
53       errorCallback();
54     }
55   }
56
57   StateEnum state;
58   AsyncSocketException exception;
59   VoidCallback successCallback;
60   VoidCallback errorCallback;
61 };
62
63 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
64  public:
65   WriteCallback()
66     : state(STATE_WAITING)
67     , bytesWritten(0)
68     , exception(AsyncSocketException::UNKNOWN, "none") {}
69
70   void writeSuccess() noexcept override {
71     state = STATE_SUCCEEDED;
72     if (successCallback) {
73       successCallback();
74     }
75   }
76
77   void writeErr(size_t bytesWritten,
78                 const AsyncSocketException& ex) noexcept override {
79     state = STATE_FAILED;
80     this->bytesWritten = bytesWritten;
81     exception = ex;
82     if (errorCallback) {
83       errorCallback();
84     }
85   }
86
87   StateEnum state;
88   size_t bytesWritten;
89   AsyncSocketException exception;
90   VoidCallback successCallback;
91   VoidCallback errorCallback;
92 };
93
94 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
95  public:
96   ReadCallback()
97     : state(STATE_WAITING)
98     , exception(AsyncSocketException::UNKNOWN, "none")
99     , buffers() {}
100
101   ~ReadCallback() {
102     for (std::vector<Buffer>::iterator it = buffers.begin();
103          it != buffers.end();
104          ++it) {
105       it->free();
106     }
107     currentBuffer.free();
108   }
109
110   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
111     if (!currentBuffer.buffer) {
112       currentBuffer.allocate(4096);
113     }
114     *bufReturn = currentBuffer.buffer;
115     *lenReturn = currentBuffer.length;
116   }
117
118   void readDataAvailable(size_t len) noexcept override {
119     currentBuffer.length = len;
120     buffers.push_back(currentBuffer);
121     currentBuffer.reset();
122     if (dataAvailableCallback) {
123       dataAvailableCallback();
124     }
125   }
126
127   void readEOF() noexcept override {
128     state = STATE_SUCCEEDED;
129   }
130
131   void readErr(const AsyncSocketException& ex) noexcept override {
132     state = STATE_FAILED;
133     exception = ex;
134   }
135
136   void verifyData(const char* expected, size_t expectedLen) const {
137     size_t offset = 0;
138     for (size_t idx = 0; idx < buffers.size(); ++idx) {
139       const auto& buf = buffers[idx];
140       size_t cmpLen = std::min(buf.length, expectedLen - offset);
141       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
142       CHECK_EQ(cmpLen, buf.length);
143       offset += cmpLen;
144     }
145     CHECK_EQ(offset, expectedLen);
146   }
147
148   class Buffer {
149    public:
150     Buffer() : buffer(nullptr), length(0) {}
151     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
152
153     void reset() {
154       buffer = nullptr;
155       length = 0;
156     }
157     void allocate(size_t length) {
158       assert(buffer == nullptr);
159       this->buffer = static_cast<char*>(malloc(length));
160       this->length = length;
161     }
162     void free() {
163       ::free(buffer);
164       reset();
165     }
166
167     char* buffer;
168     size_t length;
169   };
170
171   StateEnum state;
172   AsyncSocketException exception;
173   std::vector<Buffer> buffers;
174   Buffer currentBuffer;
175   VoidCallback dataAvailableCallback;
176 };
177
178 class ReadVerifier {
179 };
180
181 class TestServer {
182  public:
183   // Create a TestServer.
184   // This immediately starts listening on an ephemeral port.
185   TestServer()
186     : fd_(-1) {
187     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
188     if (fd_ < 0) {
189       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
190                                 "failed to create test server socket", errno);
191     }
192     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
193       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
194                                 "failed to put test server socket in "
195                                 "non-blocking mode", errno);
196     }
197     if (listen(fd_, 10) != 0) {
198       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
199                                 "failed to listen on test server socket",
200                                 errno);
201     }
202
203     address_.setFromLocalAddress(fd_);
204     // The local address will contain 0.0.0.0.
205     // Change it to 127.0.0.1, so it can be used to connect to the server
206     address_.setFromIpPort("127.0.0.1", address_.getPort());
207   }
208
209   // Get the address for connecting to the server
210   const folly::SocketAddress& getAddress() const {
211     return address_;
212   }
213
214   int acceptFD(int timeout=50) {
215     struct pollfd pfd;
216     pfd.fd = fd_;
217     pfd.events = POLLIN;
218     int ret = poll(&pfd, 1, timeout);
219     if (ret == 0) {
220       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
221                                 "test server accept() timed out");
222     } else if (ret < 0) {
223       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
224                                 "test server accept() poll failed", errno);
225     }
226
227     int acceptedFd = ::accept(fd_, nullptr, nullptr);
228     if (acceptedFd < 0) {
229       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
230                                 "test server accept() failed", errno);
231     }
232
233     return acceptedFd;
234   }
235
236   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
237     int fd = acceptFD(timeout);
238     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
239   }
240
241   std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
242     int fd = acceptFD(timeout);
243     return AsyncSocket::newSocket(evb, fd);
244   }
245
246   /**
247    * Accept a connection, read data from it, and verify that it matches the
248    * data in the specified buffer.
249    */
250   void verifyConnection(const char* buf, size_t len) {
251     // accept a connection
252     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
253     // read the data and compare it to the specified buffer
254     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
255     acceptedSocket->readAll(readbuf.get(), len);
256     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
257     // make sure we get EOF next
258     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
259     CHECK_EQ(bytesRead, 0);
260   }
261
262  private:
263   int fd_;
264   folly::SocketAddress address_;
265 };