/* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+ if (preReceivedData_) {
+ handleRead();
+ }
}
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
if (!out) {
return 0;
}
- auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
BIO_clear_retry_flags(b);
- if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
- BIO_set_retry_read(b);
+
+ auto appData = OpenSSLUtils::getBioAppData(b);
+ CHECK(appData);
+ auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+ if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+ VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+ << ", reading pre-received data";
+
+ Cursor cursor(sslSock->preReceivedData_.get());
+ auto len = cursor.pullAtMost(out, outl);
+
+ IOBufQueue queue;
+ queue.append(std::move(sslSock->preReceivedData_));
+ queue.trimStart(len);
+ sslSock->preReceivedData_ = queue.move();
+ return len;
+ } else {
+ auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+ if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+ BIO_set_retry_read(b);
+ }
+ return result;
}
- return result;
}
int AsyncSSLSocket::sslVerifyCallback(
preverifyOk;
}
+void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
+ CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+ CHECK(!preReceivedData_);
+ preReceivedData_ = std::move(data);
+}
+
void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true;
clientHelloInfo_.reset(new ssl::ClientHelloInfo());
virtual size_t getRawBytesReceived() const override;
void enableClientHelloParsing();
+ void setPreReceivedData(std::unique_ptr<IOBuf> data);
+
/**
* Accept an SSL connection on the socket.
*
std::chrono::steady_clock::time_point handshakeEndTime_;
std::chrono::milliseconds handshakeConnectTimeout_{0};
bool sessionResumptionAttempted_{false};
+
+ std::unique_ptr<IOBuf> preReceivedData_;
};
} // namespace
EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
}
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSockPtr(
+ new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
#endif
} // namespace