/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2015-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/test/BlockingSocket.h>
+#include <folly/portability/Sockets.h>
#include <boost/scoped_array.hpp>
-#include <poll.h>
-
-// This is a test-only header
-/* using override */
-using namespace folly;
+#include <memory>
enum StateEnum {
STATE_WAITING,
typedef std::function<void()> VoidCallback;
-class ConnCallback : public AsyncSocket::ConnectCallback {
+class ConnCallback : public folly::AsyncSocket::ConnectCallback {
public:
ConnCallback()
- : state(STATE_WAITING)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
+ : state(STATE_WAITING),
+ exception(folly::AsyncSocketException::UNKNOWN, "none") {}
void connectSuccess() noexcept override {
state = STATE_SUCCEEDED;
}
}
- void connectErr(const AsyncSocketException& ex) noexcept override {
+ void connectErr(const folly::AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
if (errorCallback) {
}
StateEnum state;
- AsyncSocketException exception;
+ folly::AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
-class BufferCallback : public AsyncTransportWrapper::BufferCallback {
- public:
- BufferCallback()
- : buffered_(false) {}
-
- void onEgressBuffered() override {
- buffered_ = true;
- }
-
- bool hasBuffered() const {
- return buffered_;
- }
-
- private:
- bool buffered_{false};
-};
-
-class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
public:
WriteCallback()
- : state(STATE_WAITING)
- , bytesWritten(0)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
+ : state(STATE_WAITING),
+ bytesWritten(0),
+ exception(folly::AsyncSocketException::UNKNOWN, "none") {}
void writeSuccess() noexcept override {
state = STATE_SUCCEEDED;
}
}
- void writeErr(size_t bytesWritten,
- const AsyncSocketException& ex) noexcept override {
+ void writeErr(size_t nBytesWritten,
+ const folly::AsyncSocketException& ex) noexcept override {
+ LOG(ERROR) << ex.what();
state = STATE_FAILED;
- this->bytesWritten = bytesWritten;
+ this->bytesWritten = nBytesWritten;
exception = ex;
if (errorCallback) {
errorCallback();
}
StateEnum state;
- size_t bytesWritten;
- AsyncSocketException exception;
+ std::atomic<size_t> bytesWritten;
+ folly::AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
-class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
public:
explicit ReadCallback(size_t _maxBufferSz = 4096)
- : state(STATE_WAITING)
- , exception(AsyncSocketException::UNKNOWN, "none")
- , buffers()
- , maxBufferSz(_maxBufferSz) {}
+ : state(STATE_WAITING),
+ exception(folly::AsyncSocketException::UNKNOWN, "none"),
+ buffers(),
+ maxBufferSz(_maxBufferSz) {}
- ~ReadCallback() {
+ ~ReadCallback() override {
for (std::vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
state = STATE_SUCCEEDED;
}
- void readErr(const AsyncSocketException& ex) noexcept override {
+ void readErr(const folly::AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
buffer = nullptr;
length = 0;
}
- void allocate(size_t length) {
+ void allocate(size_t len) {
assert(buffer == nullptr);
- this->buffer = static_cast<char*>(malloc(length));
- this->length = length;
+ this->buffer = static_cast<char*>(malloc(len));
+ this->length = len;
}
void free() {
::free(buffer);
};
StateEnum state;
- AsyncSocketException exception;
+ folly::AsyncSocketException exception;
std::vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
const size_t maxBufferSz;
};
+class BufferCallback : public folly::AsyncTransport::BufferCallback {
+ public:
+ BufferCallback() : buffered_(false), bufferCleared_(false) {}
+
+ void onEgressBuffered() override { buffered_ = true; }
+
+ void onEgressBufferCleared() override { bufferCleared_ = true; }
+
+ bool hasBuffered() const { return buffered_; }
+
+ bool hasBufferCleared() const { return bufferCleared_; }
+
+ private:
+ bool buffered_{false};
+ bool bufferCleared_{false};
+};
+
class ReadVerifier {
};
+class TestSendMsgParamsCallback :
+ public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+ TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+ : flags_(flags),
+ writeFlags_(folly::WriteFlags::NONE),
+ dataSize_(dataSize),
+ data_(data),
+ queriedFlags_(false),
+ queriedData_(false)
+ {}
+
+ void reset(int flags) {
+ flags_ = flags;
+ writeFlags_ = folly::WriteFlags::NONE;
+ queriedFlags_ = false;
+ queriedData_ = false;
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ queriedFlags_ = true;
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ return flags_;
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ queriedData_ = true;
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ assert(data != nullptr);
+ memcpy(data, data_, dataSize_);
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ if (writeFlags_ == folly::WriteFlags::NONE) {
+ writeFlags_ = flags;
+ } else {
+ assert(flags == writeFlags_);
+ }
+ return dataSize_;
+ }
+
+ int flags_;
+ folly::WriteFlags writeFlags_;
+ uint32_t dataSize_;
+ void* data_;
+ bool queriedFlags_;
+ bool queriedData_;
+};
+
class TestServer {
public:
// Create a TestServer.
// This immediately starts listening on an ephemeral port.
- TestServer()
- : fd_(-1) {
- fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+ explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
+ namespace fsp = folly::portability::sockets;
+ fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
if (fd_ < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to create test server socket", errno);
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "failed to create test server socket",
+ errno);
}
if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to put test server socket in "
- "non-blocking mode", errno);
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "failed to put test server socket in "
+ "non-blocking mode",
+ errno);
+ }
+ if (enableTFO) {
+#if FOLLY_ALLOW_TFO
+ folly::detail::tfo_enable(fd_, 100);
+#endif
}
+
+ struct addrinfo hints, *res;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_PASSIVE;
+
+ if (getaddrinfo(nullptr, "0", &hints, &res)) {
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "Attempted to bind address to socket with "
+ "bad getaddrinfo",
+ errno);
+ }
+
+ SCOPE_EXIT {
+ freeaddrinfo(res);
+ };
+
+ if (bufSize > 0) {
+ setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
+ setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
+ }
+
+ if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "failed to bind to async server socket for port 10",
+ errno);
+ }
+
if (listen(fd_, 10) != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to listen on test server socket",
- errno);
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "failed to listen on test server socket",
+ errno);
}
address_.setFromLocalAddress(fd_);
address_.setFromIpPort("127.0.0.1", address_.getPort());
}
+ ~TestServer() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+ }
+
// Get the address for connecting to the server
const folly::SocketAddress& getAddress() const {
return address_;
}
int acceptFD(int timeout=50) {
+ namespace fsp = folly::portability::sockets;
struct pollfd pfd;
pfd.fd = fd_;
pfd.events = POLLIN;
int ret = poll(&pfd, 1, timeout);
if (ret == 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() timed out");
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() timed out");
} else if (ret < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() poll failed", errno);
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() poll failed",
+ errno);
}
- int acceptedFd = ::accept(fd_, nullptr, nullptr);
+ int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
if (acceptedFd < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() failed", errno);
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() failed",
+ errno);
}
return acceptedFd;
std::shared_ptr<BlockingSocket> accept(int timeout=50) {
int fd = acceptFD(timeout);
- return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+ return std::make_shared<BlockingSocket>(fd);
}
- std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+ std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
+ int timeout = 50) {
int fd = acceptFD(timeout);
- return AsyncSocket::newSocket(evb, fd);
+ return folly::AsyncSocket::newSocket(evb, fd);
}
/**