+
+class ReadCallbackTerminator : public ReadCallback {
+ public:
+ ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
+ : ReadCallback(wcb)
+ , base_(base) {}
+
+ // Do not write data back, terminate the loop.
+ void readDataAvailable(size_t len) noexcept override {
+ std::cerr << "readDataAvailable, len " << len << std::endl;
+
+ currentBuffer.length = len;
+
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ state = STATE_SUCCEEDED;
+
+ socket_->setReadCB(nullptr);
+ base_->terminateLoopSoon();
+ }
+ private:
+ EventBase* base_;
+};
+
+
+/**
+ * Test a full unencrypted codepath
+ */
+TEST(AsyncSSLSocketTest, UnencryptedTest) {
+ EventBase base;
+
+ auto clientCtx = std::make_shared<folly::SSLContext>();
+ auto serverCtx = std::make_shared<folly::SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+ auto client = AsyncSSLSocket::newSocket(
+ clientCtx, &base, fds[0], false, true);
+ auto server = AsyncSSLSocket::newSocket(
+ serverCtx, &base, fds[1], true, true);
+
+ ReadCallbackTerminator readCallback(&base, nullptr);
+ server->setReadCB(&readCallback);
+ readCallback.setSocket(server);
+
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ client->write(nullptr, buf, sizeof(buf));
+
+ // Check that bytes are unencrypted
+ char c;
+ EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
+ EXPECT_EQ('a', c);
+
+ EventBaseAborter eba(&base, 3000);
+ base.loop();
+
+ EXPECT_EQ(1, readCallback.buffers.size());
+ EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
+
+ server->setReadCB(&readCallback);
+
+ // Unencrypted
+ server->sslAccept(nullptr);
+ client->sslConn(nullptr);
+
+ // Do NOT wait for handshake, writing should be queued and happen after
+
+ client->write(nullptr, buf, sizeof(buf));
+
+ // Check that bytes are *not* unencrypted
+ char c2;
+ EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
+ EXPECT_NE('a', c2);
+
+
+ base.loop();
+
+ EXPECT_EQ(2, readCallback.buffers.size());
+ EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
+}
+
+TEST(AsyncSSLSocketTest, ConnResetErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[3] = {0x16, 0x03, 0x01};
+ socket->write(buf, sizeof(buf));
+ socket->closeWithReset();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("Network error"), std::string::npos);
+ EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[3] = {0x16, 0x03, 0x01};
+ socket->write(buf, sizeof(buf));
+ socket->close();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
+ EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[256] = {0x16, 0x03};
+ memset(buf + 2, 'a', sizeof(buf) - 2);
+ socket->write(buf, sizeof(buf));
+ socket->close();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
+ std::string::npos);
+ EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
+ std::string::npos);
+}
+
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+ using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+ explicit MockAsyncTFOSSLSocket(
+ std::shared_ptr<folly::SSLContext> sslCtx,
+ EventBase* evb)
+ : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+ MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, false);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();