+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+ // Start listening on a local port
+ EventBase evb;
+
+ // Hopefully nothing is listening on this address
+ SocketAddress addr("127.0.0.1", 65535);
+ auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, addr, 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSockPtr(
+ new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
+TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSocket::UniquePtr serverSockPtr(
+ new AsyncSocket(&serverEventBase, fds[1]));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ AsyncSSLSocket::UniquePtr serverSSLSockPtr(
+ new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
+ auto serverSSLSock = serverSSLSockPtr.get();
+ SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
+/**
+ * Test overriding the flags passed to "sendmsg()" system call,
+ * and verifying that write requests fail properly.
+ */
+TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
+ // Start listening on a local port
+ SendMsgFlagsCallback msgCallback;
+ ExpectWriteErrorCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->open();
+
+ // Setting flags to "-1" to trigger "Invalid argument" error
+ // on attempt to use this flags in sendmsg() system call.
+ msgCallback.resetFlags(-1);
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgParamsCallback test completed" << endl;
+}
+
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
+ // This test requires Linux kernel v4.6 or later
+ struct utsname s_uname;
+ memset(&s_uname, 0, sizeof(s_uname));
+ ASSERT_EQ(uname(&s_uname), 0);
+ int major, minor;
+ folly::StringPiece extra;
+ if (folly::split<false>(
+ '.', std::string(s_uname.release) + ".", major, minor, extra)) {
+ if (major < 4 || (major == 4 && minor < 6)) {
+ LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
+ << "kernel ver. " << s_uname.release << " detected).";
+ return;
+ }
+ }
+
+ // Start listening on a local port
+ SendMsgDataCallback msgCallback;
+ WriteCheckTimestampCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->open();
+
+ // Adding MSG_EOR flag to the message flags - it'll trigger
+ // timestamp generation for the last byte of the message.
+ msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
+
+ // Init ancillary data buffer to trigger timestamp notification
+ union {
+ uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
+ struct cmsghdr cmsg;
+ } u;
+ u.cmsg.cmsg_level = SOL_SOCKET;
+ u.cmsg.cmsg_type = SO_TIMESTAMPING;
+ u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
+ uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
+ SOF_TIMESTAMPING_TX_ACK;
+ memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
+ std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
+ memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
+ msgCallback.resetData(std::move(ctrl));
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::vector<uint8_t> readbuf(buf.size());
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, buf.size());
+ EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+
+ writeCallback.checkForTimestampNotifications();
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgDataCallback test completed" << endl;
+}
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
+