Add writable() to AsyncTransport
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
index bfde20c79ba9270e9156b7e9b793ae6ed7e7c48a..1b48d260697660f07b26c7cf852b8e39ec2cacf9 100644 (file)
@@ -1096,6 +1096,44 @@ TEST(AsyncSocketTest, WritePipeError) {
   ASSERT_FALSE(socket->isClosedByPeer());
 }
 
+/**
+ * Test writing to a socket that has its read side closed
+ */
+TEST(AsyncSocketTest, WriteAfterReadEOF) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket =
+      AsyncSocket::newSocket(&evb, server.getAddress(), 30);
+  evb.loop(); // loop until the socket is connected
+
+  // Accept the connection
+  std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
+  ReadCallback rcb;
+  acceptedSocket->setReadCB(&rcb);
+
+  // Shutdown the write side of client socket (read side of server socket)
+  socket->shutdownWrite();
+  evb.loop();
+
+  // Check that accepted socket is still writable
+  ASSERT_FALSE(acceptedSocket->good());
+  ASSERT_TRUE(acceptedSocket->writable());
+
+  // Write data to accepted socket
+  constexpr size_t simpleBufLength = 5;
+  char simpleBuf[simpleBufLength];
+  memset(simpleBuf, 'a', simpleBufLength);
+  WriteCallback wcb;
+  acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
+  evb.loop();
+
+  // Make sure we were able to write even after getting a read EOF
+  ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+}
+
 /**
  * Test that bytes written is correctly computed in case of write failure
  */
@@ -2797,6 +2835,8 @@ TEST(AsyncSocketTest, ConnectTFOWithBigData) {
   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
 }
 
+#endif // FOLLY_ALLOW_TFO
+
 class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
  public:
   MOCK_METHOD1(evbAttached, void(AsyncSocket*));
@@ -2804,7 +2844,7 @@ class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
 };
 
 TEST(AsyncSocketTest, EvbCallbacks) {
-  auto cb = folly::make_unique<MockEvbChangeCallback>();
+  auto cb = std::make_unique<MockEvbChangeCallback>();
   EventBase evb;
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
 
@@ -2821,21 +2861,11 @@ TEST(AsyncSocketTest, EvbCallbacks) {
 /* copied from include/uapi/linux/net_tstamp.h */
 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
 enum SOF_TIMESTAMPING {
-  // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
-  // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
-  // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
-  // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
-  // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
-  // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
-  // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
   SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
-
-  // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
-  // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
 };
 TEST(AsyncSocketTest, ErrMessageCallback) {
   TestServer server;
@@ -2908,4 +2938,285 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
 }
 #endif // MSG_ERRQUEUE
 
-#endif
+TEST(AsyncSocket, PreReceivedData) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback(2);
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("he", 2);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
+    acceptedSocket->setReadCB(nullptr);
+    acceptedSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    if (readCallback.dataRead() == 5) {
+      readCallback.verifyData("hello", 5);
+      acceptedSocket->setReadCB(nullptr);
+    }
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataOnly) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback;
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hello", 5);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    readCallback.verifyData("hello", 5);
+    acceptedSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataPartial) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback;
+  ReadCallback smallReadCallback(3);
+  ReadCallback normalReadCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hello", 5);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(&smallReadCallback);
+  };
+  smallReadCallback.dataAvailableCallback = [&]() {
+    smallReadCallback.verifyData("hel", 3);
+    acceptedSocket->setReadCB(&normalReadCallback);
+  };
+  normalReadCallback.dataAvailableCallback = [&]() {
+    normalReadCallback.verifyData("lo", 2);
+    acceptedSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataTakeover) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket =
+      AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
+  AsyncSocket::UniquePtr takeoverSocket;
+
+  ReadCallback peekCallback(3);
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hel", 3);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(nullptr);
+    takeoverSocket =
+        AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
+    takeoverSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    readCallback.verifyData("hello", 5);
+    takeoverSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocketTest, SendMessageFlags) {
+  TestServer server;
+  TestSendMsgParamsCallback sendMsgCB(
+      MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  evb.loop();
+  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+  // Set SendMsgParamsCallback
+  socket->setSendMsgParamCB(&sendMsgCB);
+  ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
+
+  // Write the first portion of data. This data is expected to be
+  // sent out immediately.
+  std::vector<uint8_t> buf(128, 'a');
+  WriteCallback wcb;
+  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
+  socket->write(&wcb, buf.data(), buf.size());
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+  ASSERT_TRUE(sendMsgCB.queriedFlags_);
+  ASSERT_FALSE(sendMsgCB.queriedData_);
+
+  // Using different flags for the second write operation.
+  // MSG_MORE flag is expected to delay sending this
+  // data to the wire.
+  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
+  socket->write(&wcb, buf.data(), buf.size());
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+  ASSERT_TRUE(sendMsgCB.queriedFlags_);
+  ASSERT_FALSE(sendMsgCB.queriedData_);
+
+  // Make sure the accepted socket saw only the data from
+  // the first write request.
+  std::vector<uint8_t> readbuf(2 * buf.size());
+  uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
+  ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+  ASSERT_EQ(bytesRead, buf.size());
+
+  // Make sure the server got a connection and received the data
+  acceptedSocket->close();
+  socket->close();
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+TEST(AsyncSocketTest, SendMessageAncillaryData) {
+  int fds[2];
+  EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);
+
+  // "Client" socket
+  int cfd = fds[0];
+  ASSERT_NE(cfd, -1);
+
+  // "Server" socket
+  int sfd = fds[1];
+  ASSERT_NE(sfd, -1);
+  SCOPE_EXIT { close(sfd); };
+
+  // Instantiate AsyncSocket object for the connected socket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);
+
+  // Open a temporary file and write a magic string to it
+  // We'll transfer the file handle to test the message parameters
+  // callback logic.
+  TemporaryFile file(StringPiece(),
+                     fs::path(),
+                     TemporaryFile::Scope::UNLINK_IMMEDIATELY);
+  int tmpfd = file.fd();
+  ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
+  std::string magicString("Magic string");
+  ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
+            magicString.length());
+
+  // Send message
+  union {
+    // Space large enough to hold an 'int'
+    char control[CMSG_SPACE(sizeof(int))];
+    struct cmsghdr cmh;
+  } s_u;
+  s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
+  s_u.cmh.cmsg_level = SOL_SOCKET;
+  s_u.cmh.cmsg_type = SCM_RIGHTS;
+  memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
+
+  // Set up the callback providing message parameters
+  TestSendMsgParamsCallback sendMsgCB(
+      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
+  socket->setSendMsgParamCB(&sendMsgCB);
+
+  // We must transmit at least 1 byte of real data in order
+  // to send ancillary data
+  int s_data = 12345;
+  WriteCallback wcb;
+  socket->write(&wcb, &s_data, sizeof(s_data));
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Receive the message
+  union {
+    // Space large enough to hold an 'int'
+    char control[CMSG_SPACE(sizeof(int))];
+    struct cmsghdr cmh;
+  } r_u;
+  struct msghdr msgh;
+  struct iovec iov;
+  int r_data = 0;
+
+  msgh.msg_control = r_u.control;
+  msgh.msg_controllen = sizeof(r_u.control);
+  msgh.msg_name = nullptr;
+  msgh.msg_namelen = 0;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  iov.iov_base = &r_data;
+  iov.iov_len = sizeof(r_data);
+
+  // Receive data
+  ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
+
+  // Validate the received message
+  ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
+  ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
+  ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
+  ASSERT_EQ(r_data, s_data);
+  int fd = 0;
+  memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
+  ASSERT_NE(fd, 0);
+  SCOPE_EXIT { close(fd); };
+
+  std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
+
+  // Reposition to the beginning of the file
+  ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
+
+  // Read the magic string back, and compare it with the original
+  ASSERT_EQ(
+      magicString.length(),
+      read(fd, transferredMagicString.data(), transferredMagicString.size()));
+  ASSERT_TRUE(std::equal(
+      magicString.begin(),
+      magicString.end(),
+      transferredMagicString.begin()));
+}