ASSERT_FALSE(socket->isClosedByPeer());
}
+/**
+ * Test writing to a socket that has its read side closed
+ */
+TEST(AsyncSocketTest, WriteAfterReadEOF) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket =
+ AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+ evb.loop(); // loop until the socket is connected
+
+ // Accept the connection
+ std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+ ReadCallback rcb;
+ acceptedSocket->setReadCB(&rcb);
+
+ // Shutdown the write side of client socket (read side of server socket)
+ socket->shutdownWrite();
+ evb.loop();
+
+ // Check that accepted socket is still writable
+ ASSERT_FALSE(acceptedSocket->good());
+ ASSERT_TRUE(acceptedSocket->writable());
+
+ // Write data to accepted socket
+ constexpr size_t simpleBufLength = 5;
+ char simpleBuf[simpleBufLength];
+ memset(simpleBuf, 'a', simpleBufLength);
+ WriteCallback wcb;
+ acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
+ evb.loop();
+
+ // Make sure we were able to write even after getting a read EOF
+ ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+}
+
/**
* Test that bytes written is correctly computed in case of write failure
*/
EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
+#endif // FOLLY_ALLOW_TFO
+
class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
public:
MOCK_METHOD1(evbAttached, void(AsyncSocket*));
};
TEST(AsyncSocketTest, EvbCallbacks) {
- auto cb = folly::make_unique<MockEvbChangeCallback>();
+ auto cb = std::make_unique<MockEvbChangeCallback>();
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
/* copied from include/uapi/linux/net_tstamp.h */
/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
enum SOF_TIMESTAMPING {
- // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
- // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
- // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
- // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
- // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
- // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
SOF_TIMESTAMPING_OPT_ID = (1 << 7),
SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
- // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
-
- // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
- // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
};
TEST(AsyncSocketTest, ErrMessageCallback) {
TestServer server;
}
#endif // MSG_ERRQUEUE
-#endif
+TEST(AsyncSocket, PreReceivedData) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback(2);
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("he", 2);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
+ acceptedSocket->setReadCB(nullptr);
+ acceptedSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ if (readCallback.dataRead() == 5) {
+ readCallback.verifyData("hello", 5);
+ acceptedSocket->setReadCB(nullptr);
+ }
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataOnly) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback;
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hello", 5);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ readCallback.verifyData("hello", 5);
+ acceptedSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataPartial) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback;
+ ReadCallback smallReadCallback(3);
+ ReadCallback normalReadCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hello", 5);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(&smallReadCallback);
+ };
+ smallReadCallback.dataAvailableCallback = [&]() {
+ smallReadCallback.verifyData("hel", 3);
+ acceptedSocket->setReadCB(&normalReadCallback);
+ };
+ normalReadCallback.dataAvailableCallback = [&]() {
+ normalReadCallback.verifyData("lo", 2);
+ acceptedSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataTakeover) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket =
+ AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
+ AsyncSocket::UniquePtr takeoverSocket;
+
+ ReadCallback peekCallback(3);
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hel", 3);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(nullptr);
+ takeoverSocket =
+ AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
+ takeoverSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ readCallback.verifyData("hello", 5);
+ takeoverSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocketTest, SendMessageFlags) {
+ TestServer server;
+ TestSendMsgParamsCallback sendMsgCB(
+ MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+ evb.loop();
+ ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+ // Set SendMsgParamsCallback
+ socket->setSendMsgParamCB(&sendMsgCB);
+ ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
+
+ // Write the first portion of data. This data is expected to be
+ // sent out immediately.
+ std::vector<uint8_t> buf(128, 'a');
+ WriteCallback wcb;
+ sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
+ socket->write(&wcb, buf.data(), buf.size());
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+ ASSERT_TRUE(sendMsgCB.queriedFlags_);
+ ASSERT_FALSE(sendMsgCB.queriedData_);
+
+ // Using different flags for the second write operation.
+ // MSG_MORE flag is expected to delay sending this
+ // data to the wire.
+ sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
+ socket->write(&wcb, buf.data(), buf.size());
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+ ASSERT_TRUE(sendMsgCB.queriedFlags_);
+ ASSERT_FALSE(sendMsgCB.queriedData_);
+
+ // Make sure the accepted socket saw only the data from
+ // the first write request.
+ std::vector<uint8_t> readbuf(2 * buf.size());
+ uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
+ ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+ ASSERT_EQ(bytesRead, buf.size());
+
+ // Make sure the server got a connection and received the data
+ acceptedSocket->close();
+ socket->close();
+
+ ASSERT_TRUE(socket->isClosedBySelf());
+ ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+TEST(AsyncSocketTest, SendMessageAncillaryData) {
+ int fds[2];
+ EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);
+
+ // "Client" socket
+ int cfd = fds[0];
+ ASSERT_NE(cfd, -1);
+
+ // "Server" socket
+ int sfd = fds[1];
+ ASSERT_NE(sfd, -1);
+ SCOPE_EXIT { close(sfd); };
+
+ // Instantiate AsyncSocket object for the connected socket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);
+
+ // Open a temporary file and write a magic string to it
+ // We'll transfer the file handle to test the message parameters
+ // callback logic.
+ TemporaryFile file(StringPiece(),
+ fs::path(),
+ TemporaryFile::Scope::UNLINK_IMMEDIATELY);
+ int tmpfd = file.fd();
+ ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
+ std::string magicString("Magic string");
+ ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
+ magicString.length());
+
+ // Send message
+ union {
+ // Space large enough to hold an 'int'
+ char control[CMSG_SPACE(sizeof(int))];
+ struct cmsghdr cmh;
+ } s_u;
+ s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
+ s_u.cmh.cmsg_level = SOL_SOCKET;
+ s_u.cmh.cmsg_type = SCM_RIGHTS;
+ memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
+
+ // Set up the callback providing message parameters
+ TestSendMsgParamsCallback sendMsgCB(
+ MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
+ socket->setSendMsgParamCB(&sendMsgCB);
+
+ // We must transmit at least 1 byte of real data in order
+ // to send ancillary data
+ int s_data = 12345;
+ WriteCallback wcb;
+ socket->write(&wcb, &s_data, sizeof(s_data));
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Receive the message
+ union {
+ // Space large enough to hold an 'int'
+ char control[CMSG_SPACE(sizeof(int))];
+ struct cmsghdr cmh;
+ } r_u;
+ struct msghdr msgh;
+ struct iovec iov;
+ int r_data = 0;
+
+ msgh.msg_control = r_u.control;
+ msgh.msg_controllen = sizeof(r_u.control);
+ msgh.msg_name = nullptr;
+ msgh.msg_namelen = 0;
+ msgh.msg_iov = &iov;
+ msgh.msg_iovlen = 1;
+ iov.iov_base = &r_data;
+ iov.iov_len = sizeof(r_data);
+
+ // Receive data
+ ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
+
+ // Validate the received message
+ ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
+ ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
+ ASSERT_EQ(r_data, s_data);
+ int fd = 0;
+ memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
+ ASSERT_NE(fd, 0);
+ SCOPE_EXIT { close(fd); };
+
+ std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
+
+ // Reposition to the beginning of the file
+ ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
+
+ // Read the magic string back, and compare it with the original
+ ASSERT_EQ(
+ magicString.length(),
+ read(fd, transferredMagicString.data(), transferredMagicString.size()));
+ ASSERT_TRUE(std::equal(
+ magicString.begin(),
+ magicString.end(),
+ transferredMagicString.begin()));
+}