2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
21 #include <boost/scoped_array.hpp>
// Type-erased callable taking no arguments and returning nothing. The test
// callback classes below store optional hooks of this type (successCallback,
// errorCallback, dataAvailableCallback).
30 typedef std::function<void()> VoidCallback;
// Test implementation of folly::AsyncSocket::ConnectCallback that records the
// outcome of a connect attempt in `state` and keeps an exception object plus
// two optional VoidCallback hooks.
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
// Initial condition: STATE_WAITING with a placeholder exception
// (UNKNOWN, "none") until a real result arrives.
35 : state(STATE_WAITING),
36 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
// Connect succeeded: flip the state, then test whether a success hook is set.
// The statement inside the `if` is not visible in this excerpt -- presumably
// it invokes successCallback (TODO confirm).
38 void connectSuccess() noexcept override {
39 state = STATE_SUCCEEDED;
40 if (successCallback) {
// Connect failed with `ex`. The body is not visible in this excerpt --
// presumably it records `ex` and runs errorCallback (TODO confirm).
45 void connectErr(const folly::AsyncSocketException& ex) noexcept override {
// Exception slot; holds the (UNKNOWN, "none") placeholder set by the
// constructor.
54 folly::AsyncSocketException exception;
// Optional hooks; an empty std::function means "no hook".
55 VoidCallback successCallback;
56 VoidCallback errorCallback;
// Test implementation of folly::AsyncTransportWrapper::WriteCallback that
// records the outcome of a write in `state`, remembers how many bytes were
// written on error, and keeps two optional VoidCallback hooks.
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
// Initial condition: STATE_WAITING with a placeholder exception
// (UNKNOWN, "none"). One initializer between these two lines is not visible
// in this excerpt.
62 : state(STATE_WAITING),
64 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
// Write completed: flip the state, then test whether a success hook is set.
// The statement inside the `if` is not visible here -- presumably it invokes
// successCallback (TODO confirm).
66 void writeSuccess() noexcept override {
67 state = STATE_SUCCEEDED;
68 if (successCallback) {
// Write failed after `bytesWritten` bytes with error `ex`. Only part of the
// body is visible in this excerpt.
73 void writeErr(size_t bytesWritten,
74 const folly::AsyncSocketException& ex) noexcept override {
// `this->` is required because the parameter shadows the member of the same
// name.
76 this->bytesWritten = bytesWritten;
// Exception slot; holds the (UNKNOWN, "none") placeholder set by the
// constructor.
85 folly::AsyncSocketException exception;
// Optional hooks; an empty std::function means "no hook".
86 VoidCallback successCallback;
87 VoidCallback errorCallback;
// Test implementation of folly::AsyncTransportWrapper::ReadCallback.
// Every chunk delivered by the transport is kept as a separate Buffer in
// `buffers`, so tests can later inspect or verify everything that was read.
90 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
// _maxBufferSz: size passed to Buffer::allocate() for each read buffer
// (default 4096). Starts in STATE_WAITING with a placeholder exception
// (UNKNOWN, "none").
92 explicit ReadCallback(size_t _maxBufferSz = 4096)
93 : state(STATE_WAITING),
94 exception(folly::AsyncSocketException::UNKNOWN, "none"),
96 maxBufferSz(_maxBufferSz) {}
// Cleanup fragment (enclosing signature not visible -- presumably the
// destructor): walks every stored buffer, then frees the in-progress one.
// The loop body is not visible here; presumably it frees each entry.
99 for (std::vector<Buffer>::iterator it = buffers.begin();
104 currentBuffer.free();
// Supplies the memory for the next read through the two out-parameters.
// A buffer of maxBufferSz bytes is allocated only when none is currently
// held, so an unused buffer is handed out again on the next call.
107 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
108 if (!currentBuffer.buffer) {
109 currentBuffer.allocate(maxBufferSz);
111 *bufReturn = currentBuffer.buffer;
112 *lenReturn = currentBuffer.length;
// `len` bytes were written into the buffer handed out by getReadBuffer().
115 void readDataAvailable(size_t len) noexcept override {
// Shrink the recorded length from the allocated size to the bytes received.
116 currentBuffer.length = len;
// Store a copy of the Buffer (pointer + length) in the list, then reset the
// working buffer. reset() is not visible here -- presumably it clears the
// pointer without freeing so the stored copy keeps the memory (TODO confirm).
117 buffers.push_back(currentBuffer);
118 currentBuffer.reset();
// Notify the test, if it registered a hook.
119 if (dataAvailableCallback) {
120 dataAvailableCallback();
// End of stream is recorded as success.
124 void readEOF() noexcept override {
125 state = STATE_SUCCEEDED;
// A read error is recorded as failure; any handling of `ex` is not visible
// in this excerpt.
128 void readErr(const folly::AsyncSocketException& ex) noexcept override {
129 state = STATE_FAILED;
// Checks (via CHECK_EQ) that the concatenation of all stored buffers equals
// the first expectedLen bytes of `expected`.
// The declaration and update of `offset` are not visible in this excerpt;
// presumably it starts at 0 and advances by cmpLen per buffer (TODO confirm).
133 void verifyData(const char* expected, size_t expectedLen) const {
135 for (size_t idx = 0; idx < buffers.size(); ++idx) {
136 const auto& buf = buffers[idx];
// Never compare past the end of `expected`.
137 size_t cmpLen = std::min(buf.length, expectedLen - offset);
138 CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
// Fails when this buffer holds more bytes than `expected` has left.
139 CHECK_EQ(cmpLen, buf.length);
// Fails when fewer bytes were read than expected.
142 CHECK_EQ(offset, expectedLen);
// Iterates over all stored buffers; the loop body and return are not visible
// here -- presumably returns the sum of their lengths (TODO confirm).
145 size_t dataRead() const {
147 for (const auto& buf : buffers) {
// Buffer: a raw pointer plus a length. The default constructor yields an
// empty buffer (nullptr, 0); the second wraps an existing pointer and length.
155 Buffer() : buffer(nullptr), length(0) {}
156 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// Allocates `length` bytes with malloc; asserts that no buffer is held yet.
// `this->` is required because the parameter shadows the member.
// NOTE(review): the malloc() result is stored without a visible null check.
162 void allocate(size_t length) {
163 assert(buffer == nullptr);
164 this->buffer = static_cast<char*>(malloc(length));
165 this->length = length;
// Exception slot; holds the (UNKNOWN, "none") placeholder set by the
// constructor.
177 folly::AsyncSocketException exception;
// Completed chunks, in arrival order.
178 std::vector<Buffer> buffers;
// Buffer currently handed out to the transport (may be empty).
179 Buffer currentBuffer;
// Optional hook run after each chunk is stored.
180 VoidCallback dataAvailableCallback;
// Allocation size for each read buffer; fixed at construction.
181 const size_t maxBufferSz;
// Test implementation of folly::AsyncTransport::BufferCallback that latches
// two flags: one when onEgressBuffered() is called and one when
// onEgressBufferCleared() is called. Nothing visible here resets a flag.
184 class BufferCallback : public folly::AsyncTransport::BufferCallback {
// NOTE(review): both members also have in-class `{false}` initializers below,
// so this initializer list is redundant.
186 BufferCallback() : buffered_(false), bufferCleared_(false) {}
// Callback overrides: each simply sets its flag.
188 void onEgressBuffered() override { buffered_ = true; }
190 void onEgressBufferCleared() override { bufferCleared_ = true; }
// Accessors for the tests.
192 bool hasBuffered() const { return buffered_; }
194 bool hasBufferCleared() const { return bufferCleared_; }
// True once the corresponding callback has been called at least once.
197 bool buffered_{false};
198 bool bufferCleared_{false};
206 // Create a TestServer.
207 // This immediately starts listening on an ephemeral port.
// Constructor body (the signature line is not visible in this excerpt).
// Steps: create a TCP socket, switch it to non-blocking mode, start listening
// with a backlog of 10, then record the address clients should connect to.
// Each failing step throws folly::AsyncSocketException with INTERNAL_ERROR.
// No bind() call is visible; presumably listen() on the unbound socket is what
// assigns the ephemeral port (TODO confirm).
// NOTE(review): no close(fd_) is visible on the throw paths -- confirm whether
// the descriptor leaks when a later step throws.
210 fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
// (the failure condition for socket() is not visible here)
212 throw folly::AsyncSocketException(
213 folly::AsyncSocketException::INTERNAL_ERROR,
214 "failed to create test server socket",
// NOTE(review): F_SETFL is given O_NONBLOCK alone rather than OR-ing it into
// the flags returned by F_GETFL, so other settable status flags are cleared.
217 if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
218 throw folly::AsyncSocketException(
219 folly::AsyncSocketException::INTERNAL_ERROR,
220 "failed to put test server socket in "
224 if (listen(fd_, 10) != 0) {
225 throw folly::AsyncSocketException(
226 folly::AsyncSocketException::INTERNAL_ERROR,
227 "failed to listen on test server socket",
// Read the locally assigned address back from the socket to learn the port.
231 address_.setFromLocalAddress(fd_);
232 // The local address will contain 0.0.0.0.
233 // Change it to 127.0.0.1, so it can be used to connect to the server
234 address_.setFromIpPort("127.0.0.1", address_.getPort());
237 // Get the address for connecting to the server
// Returns a const reference, so the result is only valid while this object is
// alive. The body is not visible in this excerpt -- presumably it returns
// address_ (TODO confirm).
238 const folly::SocketAddress& getAddress() const {
// Waits for a pending connection and accepts it as a raw file descriptor.
// `timeout` is passed unchanged to poll() (milliseconds per POSIX; default 50).
// Throws folly::AsyncSocketException (INTERNAL_ERROR) when the wait times out,
// when poll() fails, or when accept() fails.
// The pollfd setup and the return statement are not visible in this excerpt;
// presumably the function returns acceptedFd (TODO confirm).
242 int acceptFD(int timeout=50) {
// The listening socket is non-blocking, so wait for readiness first.
246 int ret = poll(&pfd, 1, timeout);
// (first condition not visible; the message indicates the timeout case)
248 throw folly::AsyncSocketException(
249 folly::AsyncSocketException::INTERNAL_ERROR,
250 "test server accept() timed out");
251 } else if (ret < 0) {
252 throw folly::AsyncSocketException(
253 folly::AsyncSocketException::INTERNAL_ERROR,
254 "test server accept() poll failed",
// `::accept` selects the system call; an unqualified name would resolve to
// this class's own accept() member.
258 int acceptedFd = ::accept(fd_, nullptr, nullptr);
259 if (acceptedFd < 0) {
260 throw folly::AsyncSocketException(
261 folly::AsyncSocketException::INTERNAL_ERROR,
262 "test server accept() failed",
// Accepts one connection (see acceptFD for the timeout and the exceptions
// thrown) and wraps the new descriptor in a BlockingSocket held by shared_ptr.
269 std::shared_ptr<BlockingSocket> accept(int timeout=50) {
270 int fd = acceptFD(timeout);
271 return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
// Accepts one connection via acceptFD and returns a folly::AsyncSocket created
// from the new descriptor with `evb` as its EventBase.
// The rest of the parameter list is not visible in this excerpt; the body uses
// a `timeout` value that is forwarded to acceptFD.
274 std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
276 int fd = acceptFD(timeout);
277 return folly::AsyncSocket::newSocket(evb, fd);
281 * Accept a connection, read data from it, and verify that it matches the
282 * data in the specified buffer.
// buf/len: the exact bytes the peer is expected to send before closing.
// Any mismatch, or extra data after those bytes, fails a CHECK_EQ.
284 void verifyConnection(const char* buf, size_t len) {
285 // accept a connection
286 std::shared_ptr<BlockingSocket> acceptedSocket = accept();
287 // read the data and compare it to the specified buffer
// The scratch buffer is sized to the expected payload; scoped_array releases
// it on every exit path.
288 boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
289 acceptedSocket->readAll(readbuf.get(), len);
290 CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
291 // make sure we get EOF next
// NOTE(review): read()'s return type is not visible here; if it is signed,
// storing it in uint32_t turns a negative error value into a large positive
// number -- confirm.
292 uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
293 CHECK_EQ(bytesRead, 0);
// Address for clients to connect to. The constructor fills it from the
// listening socket's local address and then rewrites the host to 127.0.0.1,
// keeping the port.
298 folly::SocketAddress address_;