Check readCallback before calling handleRead
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20
21 #include <boost/scoped_array.hpp>
22 #include <poll.h>
23
24 // This is a test-only header
25 /* using override */
26 using namespace folly;
27
28 enum StateEnum {
29   STATE_WAITING,
30   STATE_SUCCEEDED,
31   STATE_FAILED
32 };
33
34 typedef std::function<void()> VoidCallback;
35
36 class ConnCallback : public AsyncSocket::ConnectCallback {
37  public:
38   ConnCallback()
39     : state(STATE_WAITING)
40     , exception(AsyncSocketException::UNKNOWN, "none") {}
41
42   void connectSuccess() noexcept override {
43     state = STATE_SUCCEEDED;
44     if (successCallback) {
45       successCallback();
46     }
47   }
48
49   void connectErr(const AsyncSocketException& ex) noexcept override {
50     state = STATE_FAILED;
51     exception = ex;
52     if (errorCallback) {
53       errorCallback();
54     }
55   }
56
57   StateEnum state;
58   AsyncSocketException exception;
59   VoidCallback successCallback;
60   VoidCallback errorCallback;
61 };
62
63 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
64  public:
65   WriteCallback()
66     : state(STATE_WAITING)
67     , bytesWritten(0)
68     , exception(AsyncSocketException::UNKNOWN, "none") {}
69
70   void writeSuccess() noexcept override {
71     state = STATE_SUCCEEDED;
72     if (successCallback) {
73       successCallback();
74     }
75   }
76
77   void writeErr(size_t bytesWritten,
78                 const AsyncSocketException& ex) noexcept override {
79     state = STATE_FAILED;
80     this->bytesWritten = bytesWritten;
81     exception = ex;
82     if (errorCallback) {
83       errorCallback();
84     }
85   }
86
87   StateEnum state;
88   size_t bytesWritten;
89   AsyncSocketException exception;
90   VoidCallback successCallback;
91   VoidCallback errorCallback;
92 };
93
94 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
95  public:
96   explicit ReadCallback(size_t _maxBufferSz = 4096)
97     : state(STATE_WAITING)
98     , exception(AsyncSocketException::UNKNOWN, "none")
99     , buffers()
100     , maxBufferSz(_maxBufferSz) {}
101
102   ~ReadCallback() {
103     for (std::vector<Buffer>::iterator it = buffers.begin();
104          it != buffers.end();
105          ++it) {
106       it->free();
107     }
108     currentBuffer.free();
109   }
110
111   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
112     if (!currentBuffer.buffer) {
113       currentBuffer.allocate(maxBufferSz);
114     }
115     *bufReturn = currentBuffer.buffer;
116     *lenReturn = currentBuffer.length;
117   }
118
119   void readDataAvailable(size_t len) noexcept override {
120     currentBuffer.length = len;
121     buffers.push_back(currentBuffer);
122     currentBuffer.reset();
123     if (dataAvailableCallback) {
124       dataAvailableCallback();
125     }
126   }
127
128   void readEOF() noexcept override {
129     state = STATE_SUCCEEDED;
130   }
131
132   void readErr(const AsyncSocketException& ex) noexcept override {
133     state = STATE_FAILED;
134     exception = ex;
135   }
136
137   void verifyData(const char* expected, size_t expectedLen) const {
138     size_t offset = 0;
139     for (size_t idx = 0; idx < buffers.size(); ++idx) {
140       const auto& buf = buffers[idx];
141       size_t cmpLen = std::min(buf.length, expectedLen - offset);
142       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
143       CHECK_EQ(cmpLen, buf.length);
144       offset += cmpLen;
145     }
146     CHECK_EQ(offset, expectedLen);
147   }
148
149   size_t dataRead() const {
150     size_t ret = 0;
151     for (const auto& buf : buffers) {
152       ret += buf.length;
153     }
154     return ret;
155   }
156
157   class Buffer {
158    public:
159     Buffer() : buffer(nullptr), length(0) {}
160     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
161
162     void reset() {
163       buffer = nullptr;
164       length = 0;
165     }
166     void allocate(size_t length) {
167       assert(buffer == nullptr);
168       this->buffer = static_cast<char*>(malloc(length));
169       this->length = length;
170     }
171     void free() {
172       ::free(buffer);
173       reset();
174     }
175
176     char* buffer;
177     size_t length;
178   };
179
180   StateEnum state;
181   AsyncSocketException exception;
182   std::vector<Buffer> buffers;
183   Buffer currentBuffer;
184   VoidCallback dataAvailableCallback;
185   const size_t maxBufferSz;
186 };
187
188 class ReadVerifier {
189 };
190
191 class TestServer {
192  public:
193   // Create a TestServer.
194   // This immediately starts listening on an ephemeral port.
195   TestServer()
196     : fd_(-1) {
197     fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
198     if (fd_ < 0) {
199       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
200                                 "failed to create test server socket", errno);
201     }
202     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
203       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
204                                 "failed to put test server socket in "
205                                 "non-blocking mode", errno);
206     }
207     if (listen(fd_, 10) != 0) {
208       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
209                                 "failed to listen on test server socket",
210                                 errno);
211     }
212
213     address_.setFromLocalAddress(fd_);
214     // The local address will contain 0.0.0.0.
215     // Change it to 127.0.0.1, so it can be used to connect to the server
216     address_.setFromIpPort("127.0.0.1", address_.getPort());
217   }
218
219   // Get the address for connecting to the server
220   const folly::SocketAddress& getAddress() const {
221     return address_;
222   }
223
224   int acceptFD(int timeout=50) {
225     struct pollfd pfd;
226     pfd.fd = fd_;
227     pfd.events = POLLIN;
228     int ret = poll(&pfd, 1, timeout);
229     if (ret == 0) {
230       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
231                                 "test server accept() timed out");
232     } else if (ret < 0) {
233       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
234                                 "test server accept() poll failed", errno);
235     }
236
237     int acceptedFd = ::accept(fd_, nullptr, nullptr);
238     if (acceptedFd < 0) {
239       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
240                                 "test server accept() failed", errno);
241     }
242
243     return acceptedFd;
244   }
245
246   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
247     int fd = acceptFD(timeout);
248     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
249   }
250
251   std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
252     int fd = acceptFD(timeout);
253     return AsyncSocket::newSocket(evb, fd);
254   }
255
256   /**
257    * Accept a connection, read data from it, and verify that it matches the
258    * data in the specified buffer.
259    */
260   void verifyConnection(const char* buf, size_t len) {
261     // accept a connection
262     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
263     // read the data and compare it to the specified buffer
264     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
265     acceptedSocket->readAll(readbuf.get(), len);
266     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
267     // make sure we get EOF next
268     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
269     CHECK_EQ(bytesRead, 0);
270   }
271
272  private:
273   int fd_;
274   folly::SocketAddress address_;
275 };