2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
21 #include <boost/scoped_array.hpp>
24 // This is a test-only header
26 using namespace folly;
// Alias for a callable that takes no arguments and returns nothing. The
// callback classes in this listing store optional hooks of this type
// (e.g. successCallback / errorCallback / dataAvailableCallback).
34 typedef std::function<void()> VoidCallback;
// Test helper implementing AsyncSocket::ConnectCallback. It records the
// outcome of a connect attempt in `state` and holds two optional hooks
// (successCallback / errorCallback).
// NOTE(review): several lines of this class are not visible in this listing
// (access specifier, constructor name, the rest of connectSuccess, the body
// of connectErr, closing braces); the comments describe visible lines only.
36 class ConnCallback : public AsyncSocket::ConnectCallback {
// Initial state: STATE_WAITING, with a placeholder exception of type UNKNOWN
// and message "none".
39 : state(STATE_WAITING)
40 , exception(AsyncSocketException::UNKNOWN, "none") {}
// Success path: mark the attempt as succeeded, then test whether a success
// hook is set (presumably to invoke it; the call itself is not visible).
42 void connectSuccess() noexcept override {
43 state = STATE_SUCCEEDED;
44 if (successCallback) {
// Error path: receives the failure as `ex`. Body not visible here;
// presumably it records the failure in `state` / `exception` and runs
// errorCallback. TODO confirm against the full file.
49 void connectErr(const AsyncSocketException& ex) noexcept override {
// Exception slot; initialised to the UNKNOWN/"none" placeholder above.
58 AsyncSocketException exception;
// Optional hooks; both are empty unless the test assigns them.
59 VoidCallback successCallback;
60 VoidCallback errorCallback;
// Test helper implementing AsyncTransportWrapper::WriteCallback. It records
// the outcome of a write in `state`, the byte count reported on error in
// `bytesWritten`, and holds two optional hooks.
// NOTE(review): some lines of this class are not visible in this listing
// (access specifier, constructor name, one initializer, most of writeErr,
// closing braces); the comments describe visible lines only.
63 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
// Initial state: STATE_WAITING, with a placeholder exception of type UNKNOWN
// and message "none".
66 : state(STATE_WAITING)
68 , exception(AsyncSocketException::UNKNOWN, "none") {}
// Success path: mark the write as succeeded, then test whether a success
// hook is set (presumably to invoke it; the call itself is not visible).
70 void writeSuccess() noexcept override {
71 state = STATE_SUCCEEDED;
72 if (successCallback) {
// Error path: receives the number of bytes written before the failure and
// the exception describing it.
77 void writeErr(size_t bytesWritten,
78 const AsyncSocketException& ex) noexcept override {
// `this->` is required because the parameter shadows the member of the same
// name.
80 this->bytesWritten = bytesWritten;
// Exception slot; initialised to the UNKNOWN/"none" placeholder above.
89 AsyncSocketException exception;
// Optional hooks; both are empty unless the test assigns them.
90 VoidCallback successCallback;
91 VoidCallback errorCallback;
// Test helper implementing AsyncTransportWrapper::ReadCallback. Each chunk
// delivered by the transport is kept as a separate Buffer in `buffers`, so a
// test can later compare the received bytes against expected data.
// NOTE(review): a number of lines are not visible in this listing (access
// specifiers, the destructor header, parts of verifyData / dataRead / Buffer,
// closing braces); the comments describe visible lines only.
94 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
// _maxBufferSz is the size passed to Buffer::allocate for every read buffer
// handed to the transport; it defaults to 4096.
96 explicit ReadCallback(size_t _maxBufferSz = 4096)
97 : state(STATE_WAITING)
98 , exception(AsyncSocketException::UNKNOWN, "none")
100 , maxBufferSz(_maxBufferSz) {}
// Walks every stored buffer and then frees currentBuffer. Presumably this is
// the destructor releasing the malloc'ed storage; the enclosing function
// header and loop body are not visible. TODO confirm.
103 for (std::vector<Buffer>::iterator it = buffers.begin();
108 currentBuffer.free();
// Hands the transport a buffer to read into. A new buffer of maxBufferSz
// bytes is allocated only when no current buffer exists; otherwise the
// existing one is offered again.
111 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
112 if (!currentBuffer.buffer) {
113 currentBuffer.allocate(maxBufferSz);
115 *bufReturn = currentBuffer.buffer;
116 *lenReturn = currentBuffer.length;
// Called with the number of bytes placed into the current buffer. The
// buffer's length is overwritten with `len`, the Buffer (pointer + length) is
// copied into `buffers`, and currentBuffer is reset so that the next
// getReadBuffer call allocates a new one. reset() is not visible here;
// presumably it clears the fields without freeing the memory now referenced
// from `buffers`. TODO confirm.
119 void readDataAvailable(size_t len) noexcept override {
120 currentBuffer.length = len;
121 buffers.push_back(currentBuffer);
122 currentBuffer.reset();
// Optional notification hook, run after the chunk has been stored.
123 if (dataAvailableCallback) {
124 dataAvailableCallback();
// End of stream is recorded as success.
128 void readEOF() noexcept override {
129 state = STATE_SUCCEEDED;
// A read error is recorded as failure (any further handling of `ex` is not
// visible here).
132 void readErr(const AsyncSocketException& ex) noexcept override {
133 state = STATE_FAILED;
// Checks the stored buffers against `expected` (expectedLen bytes) using
// CHECK_EQ. For each buffer, at most the remaining expected bytes are
// compared, and the buffer must not be longer than what remains. The final
// check requires `offset` to equal expectedLen. The declaration and update of
// `offset` are not visible; presumably it starts at 0 and advances by cmpLen
// per buffer. TODO confirm.
137 void verifyData(const char* expected, size_t expectedLen) const {
139 for (size_t idx = 0; idx < buffers.size(); ++idx) {
140 const auto& buf = buffers[idx];
141 size_t cmpLen = std::min(buf.length, expectedLen - offset);
142 CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
143 CHECK_EQ(cmpLen, buf.length);
146 CHECK_EQ(offset, expectedLen);
// Iterates over the stored buffers and returns a size_t; presumably the sum
// of their lengths (loop body not visible). TODO confirm.
149 size_t dataRead() const {
151 for (const auto& buf : buffers) {
// Buffer: a (pointer, length) pair. The default constructor yields an empty
// buffer (nullptr, 0); the two-argument form stores the given pointer and
// length as-is.
159 Buffer() : buffer(nullptr), length(0) {}
160 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// Allocates `length` bytes with malloc and records the size. Asserts that no
// buffer is currently held.
// NOTE(review): the malloc result is not checked for nullptr in the visible
// lines.
166 void allocate(size_t length) {
167 assert(buffer == nullptr);
168 this->buffer = static_cast<char*>(malloc(length));
169 this->length = length;
// Exception slot; initialised to the UNKNOWN/"none" placeholder above.
181 AsyncSocketException exception;
// Chunks received so far, in arrival order.
182 std::vector<Buffer> buffers;
// Buffer most recently offered to the transport and not yet stored.
183 Buffer currentBuffer;
// Optional hook invoked from readDataAvailable after each chunk is stored.
184 VoidCallback dataAvailableCallback;
// Allocation size for each read buffer; fixed at construction.
185 const size_t maxBufferSz;
// Test helper implementing AsyncTransport::BufferCallback. It remembers
// whether each of the two egress-buffer notifications has been delivered at
// least once; the flags are never cleared in the visible code.
188 class BufferCallback : public AsyncTransport::BufferCallback {
// NOTE(review): these initializers repeat the in-class default member
// initializers on the two flags below.
190 BufferCallback() : buffered_(false), bufferCleared_(false) {}
// Notification hooks: each one latches its flag to true.
192 void onEgressBuffered() override { buffered_ = true; }
194 void onEgressBufferCleared() override { bufferCleared_ = true; }
// Accessors for the latched flags.
196 bool hasBuffered() const { return buffered_; }
198 bool hasBufferCleared() const { return bufferCleared_; }
// True once onEgressBuffered() has been called.
201 bool buffered_{false};
// True once onEgressBufferCleared() has been called.
202 bool bufferCleared_{false};
210 // Create a TestServer.
211 // This immediately starts listening on an ephemeral port.
// Every failure below is reported by throwing AsyncSocketException with type
// INTERNAL_ERROR.
// NOTE(review): the constructor header and some condition / closing lines
// are not visible in this listing.
// Create an IPv4 TCP socket and keep its descriptor in fd_.
214 fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
216 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
217 "failed to create test server socket", errno);
// Put the listening socket into non-blocking mode.
// NOTE(review): F_SETFL is given O_NONBLOCK alone, which replaces the
// descriptor's status flags rather than adding to the F_GETFL value.
219 if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
220 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
221 "failed to put test server socket in "
222 "non-blocking mode", errno);
// Start listening with a backlog of 10. No bind() call is visible in this
// listing; presumably the port is assigned by the OS, matching the
// "ephemeral port" note above. TODO confirm.
224 if (listen(fd_, 10) != 0) {
225 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
226 "failed to listen on test server socket",
// Read back the address (and therefore the port) the socket ended up with.
230 address_.setFromLocalAddress(fd_);
231 // The local address will contain 0.0.0.0.
232 // Change it to 127.0.0.1, so it can be used to connect to the server
233 address_.setFromIpPort("127.0.0.1", address_.getPort());
236 // Get the address for connecting to the server
// Returns a const reference, so the result is only usable while the referenced
// object is alive. Presumably it refers to address_ (the return statement is
// not visible in this listing). TODO confirm.
237 const folly::SocketAddress& getAddress() const {
// Wait for an incoming connection and return the accepted descriptor.
// `timeout` is passed straight to poll(), which takes milliseconds; the
// default is 50. Throws AsyncSocketException (INTERNAL_ERROR) when the wait
// times out, when poll() fails, or when accept() fails.
// NOTE(review): the pollfd setup, the timeout condition and the return
// statement are not visible in this listing.
241 int acceptFD(int timeout=50) {
245 int ret = poll(&pfd, 1, timeout);
// Timeout branch: this throw carries no errno, unlike the two below.
247 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
248 "test server accept() timed out");
249 } else if (ret < 0) {
250 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
251 "test server accept() poll failed", errno);
// The peer address is not requested (both address arguments are nullptr).
254 int acceptedFd = ::accept(fd_, nullptr, nullptr);
255 if (acceptedFd < 0) {
256 throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
257 "test server accept() failed", errno);
// Accept one connection via acceptFD(timeout) and wrap the descriptor in a
// newly allocated BlockingSocket owned by the returned shared_ptr.
// Exceptions thrown by acceptFD propagate to the caller.
263 std::shared_ptr<BlockingSocket> accept(int timeout=50) {
264 int fd = acceptFD(timeout);
265 return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
// Accept one connection via acceptFD(timeout) and hand the descriptor to
// AsyncSocket::newSocket together with the supplied EventBase.
// Exceptions thrown by acceptFD propagate to the caller.
268 std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
269 int fd = acceptFD(timeout);
270 return AsyncSocket::newSocket(evb, fd);
274 * Accept a connection, read data from it, and verify that it matches the
275 * data in the specified buffer.
// `buf` / `len` give the expected payload. Mismatches are reported through
// CHECK_EQ. accept() is called with its default timeout.
277 void verifyConnection(const char* buf, size_t len) {
278 // accept a connection
279 std::shared_ptr<BlockingSocket> acceptedSocket = accept();
280 // read the data and compare it to the specified buffer
// Scratch buffer of exactly `len` bytes, released when readbuf goes out of
// scope.
281 boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
282 acceptedSocket->readAll(readbuf.get(), len);
283 CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
284 // make sure we get EOF next
// NOTE(review): the result is stored in an unsigned 32-bit value; read()'s
// return type is not visible here, so confirm a negative error return cannot
// be hidden by the conversion.
285 uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
286 CHECK_EQ(bytesRead, 0);
// Address clients should connect to: set in the constructor from the
// listening socket's local address, with the IP rewritten to 127.0.0.1.
291 folly::SocketAddress address_;