// are responsible for setting the succeeded state properly before the
// destructors are called.
+class SendMsgParamsCallbackBase :
+ public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+ SendMsgParamsCallbackBase() {}
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ oldCallback_ = socket_->getSendMsgParamsCB();
+ socket_->setSendMsgParamCB(this);
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ return oldCallback_->getFlags(flags);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+ SendMsgFlagsCallback() {}
+
+ void resetFlags(int flags) {
+ flags_ = flags;
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ if (flags_) {
+ return flags_;
+ } else {
+ return oldCallback_->getFlags(flags);
+ }
+ }
+
+ int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+ SendMsgDataCallback() {}
+
+ void resetData(std::vector<char>&& data) {
+ ancillaryData_.swap(data);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryData: copying data" << std::endl;
+ memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+ } else {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+ return ancillaryData_.size();
+ } else {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+ }
+
+ std::vector<char> ancillaryData_;
+};
+
class WriteCallbackBase :
public AsyncTransportWrapper::WriteCallback {
public:
- WriteCallbackBase()
+ explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
: state(STATE_WAITING)
, bytesWritten(0)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
+ , exception(AsyncSocketException::UNKNOWN, "none")
+ , mcb_(mcb) {}
~WriteCallbackBase() {
EXPECT_EQ(STATE_SUCCEEDED, state);
}
- void setSocket(
+ virtual void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
+ if (mcb_) {
+ mcb_->setSocket(socket);
+ }
}
- void writeSuccess() noexcept override {
+ virtual void writeSuccess() noexcept override {
std::cerr << "writeSuccess" << std::endl;
state = STATE_SUCCEEDED;
}
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
+ SendMsgParamsCallbackBase* mcb_;
+};
+
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+public:
+ explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~ExpectWriteErrorCallback() {
+ EXPECT_EQ(STATE_FAILED, state);
+ EXPECT_EQ(exception.type_,
+ AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+ EXPECT_EQ(exception.errno_, 22);
+ // Suppress the assert in ~WriteCallbackBase()
+ state = STATE_SUCCEEDED;
+ }
+};
+
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+ SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+ SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+ SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+ SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+ SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+ SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+ public WriteCallbackBase {
+public:
+ explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~WriteCheckTimestampCallback() {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
+ EXPECT_TRUE(gotTimestamp_);
+ EXPECT_TRUE(gotByteSeq_);
+ }
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) override {
+ WriteCallbackBase::setSocket(socket);
+
+ EXPECT_NE(socket_->getFd(), 0);
+ int flags = SOF_TIMESTAMPING_OPT_ID
+ | SOF_TIMESTAMPING_OPT_TSONLY
+ | SOF_TIMESTAMPING_SOFTWARE;
+ AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+ int ret = tstampingOpt.apply(socket_->getFd(), flags);
+ EXPECT_EQ(ret, 0);
+ }
+
+ void checkForTimestampNotifications() noexcept {
+ int fd = socket_->getFd();
+ std::vector<char> ctrl(1024, 0);
+ unsigned char data;
+ struct msghdr msg;
+ iovec entry;
+
+ memset(&msg, 0, sizeof(msg));
+ entry.iov_base = &data;
+ entry.iov_len = sizeof(data);
+ msg.msg_iov = &entry;
+ msg.msg_iovlen = 1;
+ msg.msg_control = ctrl.data();
+ msg.msg_controllen = ctrl.size();
+
+ int ret;
+ while (true) {
+ ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+ if (ret < 0) {
+ if (errno != EAGAIN) {
+ auto errnoCopy = errno;
+ std::cerr << "::recvmsg exited with code " << ret
+ << ", errno: " << errnoCopy << std::endl;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ "recvmsg() failed",
+ errnoCopy);
+ exception = ex;
+ }
+ return;
+ }
+
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg != nullptr && cmsg->cmsg_len != 0;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_TIMESTAMPING) {
+ gotTimestamp_ = true;
+ continue;
+ }
+
+ if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+ (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+ gotByteSeq_ = true;
+ continue;
+ }
+ }
+ }
+ }
+
+ bool gotTimestamp_{false};
+ bool gotByteSeq_{false};
};
+#endif // MSG_ERRQUEUE
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {