2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
22 #include <boost/scoped_array.hpp>
30 typedef std::function<void()> VoidCallback;
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
35 : state(STATE_WAITING),
36 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
38 void connectSuccess() noexcept override {
39 state = STATE_SUCCEEDED;
40 if (successCallback) {
45 void connectErr(const folly::AsyncSocketException& ex) noexcept override {
54 folly::AsyncSocketException exception;
55 VoidCallback successCallback;
56 VoidCallback errorCallback;
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
62 : state(STATE_WAITING),
64 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
66 void writeSuccess() noexcept override {
67 state = STATE_SUCCEEDED;
68 if (successCallback) {
73 void writeErr(size_t nBytesWritten,
74 const folly::AsyncSocketException& ex) noexcept override {
75 LOG(ERROR) << ex.what();
77 this->bytesWritten = nBytesWritten;
86 folly::AsyncSocketException exception;
87 VoidCallback successCallback;
88 VoidCallback errorCallback;
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
93 explicit ReadCallback(size_t _maxBufferSz = 4096)
94 : state(STATE_WAITING),
95 exception(folly::AsyncSocketException::UNKNOWN, "none"),
97 maxBufferSz(_maxBufferSz) {}
100 for (std::vector<Buffer>::iterator it = buffers.begin();
105 currentBuffer.free();
108 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109 if (!currentBuffer.buffer) {
110 currentBuffer.allocate(maxBufferSz);
112 *bufReturn = currentBuffer.buffer;
113 *lenReturn = currentBuffer.length;
116 void readDataAvailable(size_t len) noexcept override {
117 currentBuffer.length = len;
118 buffers.push_back(currentBuffer);
119 currentBuffer.reset();
120 if (dataAvailableCallback) {
121 dataAvailableCallback();
125 void readEOF() noexcept override {
126 state = STATE_SUCCEEDED;
129 void readErr(const folly::AsyncSocketException& ex) noexcept override {
130 state = STATE_FAILED;
134 void verifyData(const char* expected, size_t expectedLen) const {
136 for (size_t idx = 0; idx < buffers.size(); ++idx) {
137 const auto& buf = buffers[idx];
138 size_t cmpLen = std::min(buf.length, expectedLen - offset);
139 CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140 CHECK_EQ(cmpLen, buf.length);
143 CHECK_EQ(offset, expectedLen);
146 size_t dataRead() const {
148 for (const auto& buf : buffers) {
156 Buffer() : buffer(nullptr), length(0) {}
157 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
163 void allocate(size_t len) {
164 assert(buffer == nullptr);
165 this->buffer = static_cast<char*>(malloc(len));
178 folly::AsyncSocketException exception;
179 std::vector<Buffer> buffers;
180 Buffer currentBuffer;
181 VoidCallback dataAvailableCallback;
182 const size_t maxBufferSz;
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
187 BufferCallback() : buffered_(false), bufferCleared_(false) {}
189 void onEgressBuffered() override { buffered_ = true; }
191 void onEgressBufferCleared() override { bufferCleared_ = true; }
193 bool hasBuffered() const { return buffered_; }
195 bool hasBufferCleared() const { return bufferCleared_; }
198 bool buffered_{false};
199 bool bufferCleared_{false};
205 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
207 TestErrMessageCallback()
208 : exception_(folly::AsyncSocketException::UNKNOWN, "none")
211 void errMessage(const cmsghdr& cmsg) noexcept override {
212 if (cmsg.cmsg_level == SOL_SOCKET &&
213 cmsg.cmsg_type == SCM_TIMESTAMPING) {
214 gotTimestamp_ = true;
216 (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
217 (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
222 void errMessageError(
223 const folly::AsyncSocketException& ex) noexcept override {
227 folly::AsyncSocketException exception_;
228 bool gotTimestamp_{false};
229 bool gotByteSeq_{false};
232 class TestSendMsgParamsCallback :
233 public folly::AsyncSocket::SendMsgParamsCallback {
235 TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
237 writeFlags_(folly::WriteFlags::NONE),
240 queriedFlags_(false),
244 void reset(int flags) {
246 writeFlags_ = folly::WriteFlags::NONE;
247 queriedFlags_ = false;
248 queriedData_ = false;
251 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
253 queriedFlags_ = true;
254 if (writeFlags_ == folly::WriteFlags::NONE) {
257 assert(flags == writeFlags_);
262 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
264 if (writeFlags_ == folly::WriteFlags::NONE) {
267 assert(flags == writeFlags_);
269 assert(data != nullptr);
270 memcpy(data, data_, dataSize_);
273 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
274 if (writeFlags_ == folly::WriteFlags::NONE) {
277 assert(flags == writeFlags_);
283 folly::WriteFlags writeFlags_;
292 // Create a TestServer.
293 // This immediately starts listening on an ephemeral port.
294 explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
295 namespace fsp = folly::portability::sockets;
296 fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
298 throw folly::AsyncSocketException(
299 folly::AsyncSocketException::INTERNAL_ERROR,
300 "failed to create test server socket",
303 if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
304 throw folly::AsyncSocketException(
305 folly::AsyncSocketException::INTERNAL_ERROR,
306 "failed to put test server socket in "
312 folly::detail::tfo_enable(fd_, 100);
316 struct addrinfo hints, *res;
317 memset(&hints, 0, sizeof(hints));
318 hints.ai_family = AF_INET;
319 hints.ai_socktype = SOCK_STREAM;
320 hints.ai_flags = AI_PASSIVE;
322 if (getaddrinfo(nullptr, "0", &hints, &res)) {
323 throw folly::AsyncSocketException(
324 folly::AsyncSocketException::INTERNAL_ERROR,
325 "Attempted to bind address to socket with "
335 setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
336 setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
339 if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
340 throw folly::AsyncSocketException(
341 folly::AsyncSocketException::INTERNAL_ERROR,
342 "failed to bind to async server socket for port 10",
346 if (listen(fd_, 10) != 0) {
347 throw folly::AsyncSocketException(
348 folly::AsyncSocketException::INTERNAL_ERROR,
349 "failed to listen on test server socket",
353 address_.setFromLocalAddress(fd_);
354 // The local address will contain 0.0.0.0.
355 // Change it to 127.0.0.1, so it can be used to connect to the server
356 address_.setFromIpPort("127.0.0.1", address_.getPort());
365 // Get the address for connecting to the server
366 const folly::SocketAddress& getAddress() const {
370 int acceptFD(int timeout=50) {
371 namespace fsp = folly::portability::sockets;
375 int ret = poll(&pfd, 1, timeout);
377 throw folly::AsyncSocketException(
378 folly::AsyncSocketException::INTERNAL_ERROR,
379 "test server accept() timed out");
380 } else if (ret < 0) {
381 throw folly::AsyncSocketException(
382 folly::AsyncSocketException::INTERNAL_ERROR,
383 "test server accept() poll failed",
387 int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
388 if (acceptedFd < 0) {
389 throw folly::AsyncSocketException(
390 folly::AsyncSocketException::INTERNAL_ERROR,
391 "test server accept() failed",
398 std::shared_ptr<BlockingSocket> accept(int timeout=50) {
399 int fd = acceptFD(timeout);
400 return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
403 std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
405 int fd = acceptFD(timeout);
406 return folly::AsyncSocket::newSocket(evb, fd);
410 * Accept a connection, read data from it, and verify that it matches the
411 * data in the specified buffer.
413 void verifyConnection(const char* buf, size_t len) {
414 // accept a connection
415 std::shared_ptr<BlockingSocket> acceptedSocket = accept();
416 // read the data and compare it to the specified buffer
417 boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
418 acceptedSocket->readAll(readbuf.get(), len);
419 CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
420 // make sure we get EOF next
421 uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
422 CHECK_EQ(bytesRead, 0);
427 folly::SocketAddress address_;