Implementing a callback interface for folly::AsyncSocket allowing to supply an ancill...
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index 7965f8e451932cf6034cd52f98826fbf23ac6a27..ff70d3b8749683c34cdcd484bf78c0030e93747d 100644 (file)
@@ -46,24 +46,106 @@ namespace folly {
 // are responsible for setting the succeeded state properly before the
 // destructors are called.
 
+class SendMsgParamsCallbackBase :
+      public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  SendMsgParamsCallbackBase() {}
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+    oldCallback_ = socket_->getSendMsgParamsCB();
+    socket_->setSendMsgParamCB(this);
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                     override {
+    return oldCallback_->getFlags(flags);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    oldCallback_->getAncillaryData(flags, data);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    return oldCallback_->getAncillaryDataSize(flags);
+  }
+
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+  SendMsgFlagsCallback() {}
+
+  void resetFlags(int flags) {
+    flags_ = flags;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    if (flags_) {
+      return flags_;
+    } else {
+      return oldCallback_->getFlags(flags);
+    }
+  }
+
+  int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+  SendMsgDataCallback() {}
+
+  void resetData(std::vector<char>&& data) {
+    ancillaryData_.swap(data);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryData: copying data" << std::endl;
+      memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+    } else {
+      oldCallback_->getAncillaryData(flags, data);
+    }
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+      return ancillaryData_.size();
+    } else {
+      return oldCallback_->getAncillaryDataSize(flags);
+    }
+  }
+
+  std::vector<char> ancillaryData_;
+};
+
 class WriteCallbackBase :
 public AsyncTransportWrapper::WriteCallback {
 public:
-  WriteCallbackBase()
+  explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
       : state(STATE_WAITING)
       , bytesWritten(0)
-      , exception(AsyncSocketException::UNKNOWN, "none") {}
+      , exception(AsyncSocketException::UNKNOWN, "none")
+      , mcb_(mcb) {}
 
   ~WriteCallbackBase() {
     EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
-  void setSocket(
+  virtual void setSocket(
     const std::shared_ptr<AsyncSSLSocket> &socket) {
     socket_ = socket;
+    if (mcb_) {
+      mcb_->setSocket(socket);
+    }
   }
 
-  void writeSuccess() noexcept override {
+  virtual void writeSuccess() noexcept override {
     std::cerr << "writeSuccess" << std::endl;
     state = STATE_SUCCEEDED;
   }
@@ -84,7 +166,116 @@ public:
   StateEnum state;
   size_t bytesWritten;
   AsyncSocketException exception;
+  SendMsgParamsCallbackBase* mcb_;
+};
+
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+public:
+  explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+      : WriteCallbackBase(mcb) {}
+
+  ~ExpectWriteErrorCallback() {
+    EXPECT_EQ(STATE_FAILED, state);
+    EXPECT_EQ(exception.type_,
+             AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+    EXPECT_EQ(exception.errno_, 22);
+    // Suppress the assert in  ~WriteCallbackBase()
+    state = STATE_SUCCEEDED;
+  }
+};
+
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+  SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+  SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+  SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+  SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+  SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+  SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+ public WriteCallbackBase {
+public:
+  explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+    : WriteCallbackBase(mcb) {}
+
+  ~WriteCheckTimestampCallback() {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
+    EXPECT_TRUE(gotTimestamp_);
+    EXPECT_TRUE(gotByteSeq_);
+  }
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) override {
+    WriteCallbackBase::setSocket(socket);
+
+    EXPECT_NE(socket_->getFd(), 0);
+    int flags = SOF_TIMESTAMPING_OPT_ID
+                | SOF_TIMESTAMPING_OPT_TSONLY
+                | SOF_TIMESTAMPING_SOFTWARE;
+    AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+    int ret = tstampingOpt.apply(socket_->getFd(), flags);
+    EXPECT_EQ(ret, 0);
+  }
+
+  void checkForTimestampNotifications() noexcept {
+    int fd = socket_->getFd();
+    std::vector<char> ctrl(1024, 0);
+    unsigned char data;
+    struct msghdr msg;
+    iovec entry;
+
+    memset(&msg, 0, sizeof(msg));
+    entry.iov_base = &data;
+    entry.iov_len = sizeof(data);
+    msg.msg_iov = &entry;
+    msg.msg_iovlen = 1;
+    msg.msg_control = ctrl.data();
+    msg.msg_controllen = ctrl.size();
+
+    int ret;
+    while (true) {
+      ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+      if (ret < 0) {
+        if (errno != EAGAIN) {
+          auto errnoCopy = errno;
+          std::cerr << "::recvmsg exited with code " << ret
+                    << ", errno: " << errnoCopy << std::endl;
+          AsyncSocketException ex(
+            AsyncSocketException::INTERNAL_ERROR,
+            "recvmsg() failed",
+            errnoCopy);
+          exception = ex;
+        }
+        return;
+      }
+
+      for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+           cmsg != nullptr && cmsg->cmsg_len != 0;
+           cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level == SOL_SOCKET &&
+            cmsg->cmsg_type == SCM_TIMESTAMPING) {
+          gotTimestamp_ = true;
+          continue;
+        }
+
+        if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+            (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+          gotByteSeq_ = true;
+          continue;
+        }
+      }
+    }
+  }
+
+  bool gotTimestamp_{false};
+  bool gotByteSeq_{false};
 };
+#endif // MSG_ERRQUEUE
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {