Implementing a callback interface for folly::AsyncSocket allowing to supply an ancill...
authorMaxim Georgiev <maxgeorg@fb.com>
Fri, 10 Mar 2017 06:19:15 +0000 (22:19 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Mar 2017 06:20:14 +0000 (22:20 -0800)
Summary: Implementing a callback interface for folly::AsyncSocket allowing to supply an ancillary data buffer with msghdr structure to sendmsg() system call.

Reviewed By: afrind

Differential Revision: D4422168

fbshipit-source-id: 29a23b05f704aff796d368f4ac9514c49b7ce578

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/AsyncSocketTest.h
folly/io/async/test/AsyncSocketTest2.cpp

index 3b99a4c..26090b7 100644 (file)
@@ -34,7 +34,6 @@
 #include <folly/io/Cursor.h>
 #include <folly/io/IOBuf.h>
 #include <folly/portability/OpenSSL.h>
-#include <folly/portability/Unistd.h>
 
 using folly::SocketAddress;
 using folly::SSLContext;
@@ -59,6 +58,7 @@ using folly::SSLContext;
 using namespace folly::ssl;
 using folly::ssl::OpenSSLUtils;
 
+
 // We have one single dummy SSL context so that we can implement attach
 // and detach methods in a thread safe fashion without modifying opnessl.
 static SSLContext *dummyCtx = nullptr;
@@ -1624,7 +1624,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   struct msghdr msg;
   struct iovec iov;
-  int flags = 0;
   AsyncSSLSocket* tsslSock;
 
   iov.iov_base = const_cast<char*>(in);
@@ -1639,23 +1638,28 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
   CHECK(tsslSock);
 
+  WriteFlags flags = WriteFlags::NONE;
   if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
-    flags = MSG_EOR;
+    flags |= WriteFlags::EOR;
   }
 
-#ifdef MSG_NOSIGNAL
-  flags |= MSG_NOSIGNAL;
-#endif
-
-#ifdef MSG_MORE
   if (tsslSock->corkCurrentWrite_) {
-    flags |= MSG_MORE;
+    flags |= WriteFlags::CORK;
+  }
+
+  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
+  msg.msg_controllen =
+      tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
+  CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+           msg.msg_controllen);
+  if (msg.msg_controllen != 0) {
+    msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+    tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
   }
-#endif
 
   auto result = tsslSock->sendSocketMessage(
-      OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
+      OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
   BIO_clear_retry_flags(b);
   if (!result.exception && result.writeReturn <= 0) {
     if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
index 827ee8e..ddbab29 100644 (file)
@@ -185,6 +185,33 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
   struct iovec writeOps_[];     ///< write operation(s) list
 };
 
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
+                                                                      noexcept {
+  int msg_flags = MSG_DONTWAIT;
+
+#ifdef MSG_NOSIGNAL // Linux-only
+  msg_flags |= MSG_NOSIGNAL;
+#ifdef MSG_MORE
+  if (isSet(flags, WriteFlags::CORK)) {
+    // MSG_MORE tells the kernel we have more data to send, so wait for us to
+    // give it the rest of the data rather than immediately sending a partial
+    // frame, even when TCP_NODELAY is enabled.
+    msg_flags |= MSG_MORE;
+  }
+#endif // MSG_MORE
+#endif // MSG_NOSIGNAL
+  if (isSet(flags, WriteFlags::EOR)) {
+    // marks that this is the last byte of a record (response)
+    msg_flags |= MSG_EOR;
+  }
+
+  return msg_flags;
+}
+
+namespace {
+static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
+}
+
 AsyncSocket::AsyncSocket()
     : eventBase_(nullptr),
       writeTimeout_(this, nullptr),
@@ -254,6 +281,7 @@ void AsyncSocket::init() {
   shutdownSocketSet_ = nullptr;
   appBytesWritten_ = 0;
   appBytesReceived_ = 0;
+  sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
 }
 
 AsyncSocket::~AsyncSocket() {
@@ -625,6 +653,14 @@ AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
   return errMessageCallback_;
 }
 
+void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
+  sendMsgParamCallback_ = callback;
+}
+
+AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
+  return sendMsgParamCallback_;
+}
+
 void AsyncSocket::setReadCB(ReadCallback *callback) {
   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
           << ", callback=" << callback << ", state=" << state_;
@@ -1363,7 +1399,7 @@ int AsyncSocket::setTCPProfile(int profd) {
 }
 
 void AsyncSocket::ioReady(uint16_t events) noexcept {
-  VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
+  VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
           << ", events=" << std::hex << events << ", state=" << state_;
   DestructorGuard dg(this);
   assert(events & EventHandler::READ_WRITE);
@@ -2023,25 +2059,19 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
   msg.msg_namelen = 0;
   msg.msg_iov = const_cast<iovec *>(vec);
   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
-  msg.msg_control = nullptr;
-  msg.msg_controllen = 0;
   msg.msg_flags = 0;
+  msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
+  CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+           msg.msg_controllen);
 
-  int msg_flags = MSG_DONTWAIT;
-
-#ifdef MSG_NOSIGNAL // Linux-only
-  msg_flags |= MSG_NOSIGNAL;
-  if (isSet(flags, WriteFlags::CORK)) {
-    // MSG_MORE tells the kernel we have more data to send, so wait for us to
-    // give it the rest of the data rather than immediately sending a partial
-    // frame, even when TCP_NODELAY is enabled.
-    msg_flags |= MSG_MORE;
-  }
-#endif
-  if (isSet(flags, WriteFlags::EOR)) {
-    // marks that this is the last byte of a record (response)
-    msg_flags |= MSG_EOR;
+  if (msg.msg_controllen != 0) {
+    msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+    sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
+  } else {
+    msg.msg_control = nullptr;
   }
+  int msg_flags = sendMsgParamCallback_->getFlags(flags);
+
   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
   auto totalWritten = writeResult.writeReturn;
   if (totalWritten < 0) {
index 3e2adbc..fca560b 100644 (file)
@@ -139,6 +139,77 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
   };
 
+  class SendMsgParamsCallback {
+   public:
+    virtual ~SendMsgParamsCallback() = default;
+
+    /**
+     * getFlags() will be invoked to retrieve the desired flags to be passed
+     * to ::sendmsg() system call. This method was intentionally declared
+     * non-virtual, so there is no way to override it. Instead feel free to
+     * override getFlagsImpl(flags, defaultFlags) method instead, and enjoy
+     * the convenience of defaultFlags passed there.
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    int getFlags(folly::WriteFlags flags) noexcept {
+      return getFlagsImpl(flags, getDefaultFlags(flags));
+    }
+
+    /**
+     * getAncillaryData() will be invoked to initialize ancillary data
+     * buffer referred by "msg_control" field of msghdr structure passed to
+     * ::sendmsg() system call. The function assumes that the size of buffer
+     * is not smaller than the value returned by getAncillaryDataSize() method
+     * for the same combination of flags.
+     *
+     * @param flags     Write flags requested for the given write operation
+     * @param data      Pointer to ancillary data buffer to initialize.
+     */
+    virtual void getAncillaryData(
+      folly::WriteFlags /*flags*/,
+      void* /*data*/) noexcept {}
+
+    /**
+     * getAncillaryDataSize() will be invoked to retrieve the size of
+     * ancillary data buffer which should be passed to ::sendmsg() system call
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/)
+        noexcept {
+      return 0;
+    }
+
+    static const size_t maxAncillaryDataSize{0x5000};
+
+   private:
+    /**
+     * getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags)
+     * method to retrieve the flags to be passed to ::sendmsg() system call.
+     * SendMsgParamsCallback::getFlags() is calling this method, and returns
+     * its results directly to the caller in AsyncSocket.
+     * Classes inheriting from SendMsgParamsCallback are welcome to override
+     * this method to force SendMsgParamsCallback to return its own set
+     * of flags.
+     *
+     * @param flags        Write flags requested for the given write operation
+     * @param defaultflags A set of message flags returned by getDefaultFlags()
+     *                     method for the given "flags" mask.
+     */
+    virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) {
+      return defaultFlags;
+    }
+
+    /**
+     * getDefaultFlags() will be invoked by  getFlags(folly::WriteFlags flags)
+     * to retrieve the default set of flags, and pass them to getFlagsImpl(...)
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    int getDefaultFlags(folly::WriteFlags flags) noexcept;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -411,6 +482,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   ErrMessageCallback* getErrMessageCallback() const;
 
+  /**
+   * Set a pointer to SendMsgParamsCallback implementation which
+   * will be used to form ::sendmsg() system call parameters
+   *
+   */
+  void setSendMsgParamCB(SendMsgParamsCallback* callback);
+
+  /**
+   * Get a pointer to SendMsgParamsCallback implementation currently
+   * registered with this socket.
+   *
+   */
+  SendMsgParamsCallback* getSendMsgParamsCB() const;
+
   // Read and write methods
   void setReadCB(ReadCallback* callback) override;
   ReadCallback* getReadCallback() const override;
@@ -1010,6 +1095,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   ConnectCallback* connectCallback_;     ///< ConnectCallback
   ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
+  SendMsgParamsCallback*                 ///< Callback for retreaving
+      sendMsgParamCallback_;             ///< ::sendmsg() parameters
   ReadCallback* readCallback_;           ///< ReadCallback
   WriteRequest* writeReqHead_;           ///< Chain of WriteRequests
   WriteRequest* writeReqTail_;           ///< End of WriteRequest chain
@@ -1022,7 +1109,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   // socket.
   std::unique_ptr<IOBuf> preReceivedData_;
 
-  int8_t readErr_{READ_NO_ERROR};       ///< The read error encountered, if any.
+  int8_t readErr_{READ_NO_ERROR};        ///< The read error encountered, if any
 
   std::chrono::steady_clock::time_point connectStartTime_;
   std::chrono::steady_clock::time_point connectEndTime_;
index 41284b6..fbbf999 100644 (file)
@@ -32,6 +32,7 @@
 #include <folly/io/Cursor.h>
 #include <openssl/bio.h>
 #include <sys/types.h>
+#include <sys/utsname.h>
 #include <fstream>
 #include <iostream>
 #include <list>
@@ -1958,6 +1959,120 @@ TEST(AsyncSSLSocketTest, TestPreReceivedData) {
       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
 }
 
+/**
+ * Test overriding the flags passed to "sendmsg()" system call,
+ * and verifying that write requests fail properly.
+ */
+TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
+  // Start listening on a local port
+  SendMsgFlagsCallback msgCallback;
+  ExpectWriteErrorCallback writeCallback(&msgCallback);
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // Setting flags to "-1" to trigger "Invalid argument" error
+  // on attempt to use this flags in sendmsg() system call.
+  msgCallback.resetFlags(-1);
+
+  // write()
+  std::vector<uint8_t> buf(128, 'a');
+  ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
+
+  // close()
+  socket->close();
+
+  cerr << "SendMsgParamsCallback test completed" << endl;
+}
+
+#ifdef MSG_ERRQUEUE
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
+  // This test requires Linux kernel v4.6 or later
+  struct utsname s_uname;
+  memset(&s_uname, 0, sizeof(s_uname));
+  ASSERT_EQ(uname(&s_uname), 0);
+  int major, minor;
+  folly::StringPiece extra;
+  if (folly::split<false>(
+        '.', std::string(s_uname.release) + ".", major, minor, extra)) {
+    if (major < 4 || (major == 4 && minor < 6)) {
+      LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
+                << "kernel ver. " << s_uname.release << " detected).";
+      return;
+    }
+  }
+
+  // Start listening on a local port
+  SendMsgDataCallback msgCallback;
+  WriteCheckTimestampCallback writeCallback(&msgCallback);
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  // connect
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
+                                                 sslContext);
+  socket->open();
+
+  // Adding MSG_EOR flag to the message flags - it'll trigger
+  // timestamp generation for the last byte of the message.
+  msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
+
+  // Init ancillary data buffer to trigger timestamp notification
+  union {
+    uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
+    struct cmsghdr cmsg;
+  } u;
+  u.cmsg.cmsg_level = SOL_SOCKET;
+  u.cmsg.cmsg_type = SO_TIMESTAMPING;
+  u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
+  uint32_t flags =
+      SOF_TIMESTAMPING_TX_SCHED |
+      SOF_TIMESTAMPING_TX_SOFTWARE |
+      SOF_TIMESTAMPING_TX_ACK;
+  memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
+  std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
+  memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
+  msgCallback.resetData(std::move(ctrl));
+
+  // write()
+  std::vector<uint8_t> buf(128, 'a');
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::vector<uint8_t> readbuf(buf.size());
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, buf.size());
+  EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+
+  writeCallback.checkForTimestampNotifications();
+
+  // close()
+  socket->close();
+
+  cerr << "SendMsgDataCallback test completed" << endl;
+}
+#endif // MSG_ERRQUEUE
+
 #endif
 
 } // namespace
index 7965f8e..ff70d3b 100644 (file)
@@ -46,24 +46,106 @@ namespace folly {
 // are responsible for setting the succeeded state properly before the
 // destructors are called.
 
+class SendMsgParamsCallbackBase :
+      public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  SendMsgParamsCallbackBase() {}
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+    oldCallback_ = socket_->getSendMsgParamsCB();
+    socket_->setSendMsgParamCB(this);
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                     override {
+    return oldCallback_->getFlags(flags);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    oldCallback_->getAncillaryData(flags, data);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    return oldCallback_->getAncillaryDataSize(flags);
+  }
+
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+  SendMsgFlagsCallback() {}
+
+  void resetFlags(int flags) {
+    flags_ = flags;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    if (flags_) {
+      return flags_;
+    } else {
+      return oldCallback_->getFlags(flags);
+    }
+  }
+
+  int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+  SendMsgDataCallback() {}
+
+  void resetData(std::vector<char>&& data) {
+    ancillaryData_.swap(data);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryData: copying data" << std::endl;
+      memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+    } else {
+      oldCallback_->getAncillaryData(flags, data);
+    }
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+      return ancillaryData_.size();
+    } else {
+      return oldCallback_->getAncillaryDataSize(flags);
+    }
+  }
+
+  std::vector<char> ancillaryData_;
+};
+
 class WriteCallbackBase :
 public AsyncTransportWrapper::WriteCallback {
 public:
-  WriteCallbackBase()
+  explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
       : state(STATE_WAITING)
       , bytesWritten(0)
-      , exception(AsyncSocketException::UNKNOWN, "none") {}
+      , exception(AsyncSocketException::UNKNOWN, "none")
+      , mcb_(mcb) {}
 
   ~WriteCallbackBase() {
     EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
-  void setSocket(
+  virtual void setSocket(
     const std::shared_ptr<AsyncSSLSocket> &socket) {
     socket_ = socket;
+    if (mcb_) {
+      mcb_->setSocket(socket);
+    }
   }
 
-  void writeSuccess() noexcept override {
+  virtual void writeSuccess() noexcept override {
     std::cerr << "writeSuccess" << std::endl;
     state = STATE_SUCCEEDED;
   }
@@ -84,7 +166,116 @@ public:
   StateEnum state;
   size_t bytesWritten;
   AsyncSocketException exception;
+  SendMsgParamsCallbackBase* mcb_;
+};
+
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+public:
+  explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+      : WriteCallbackBase(mcb) {}
+
+  ~ExpectWriteErrorCallback() {
+    EXPECT_EQ(STATE_FAILED, state);
+    EXPECT_EQ(exception.type_,
+             AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+    EXPECT_EQ(exception.errno_, 22);
+    // Suppress the assert in  ~WriteCallbackBase()
+    state = STATE_SUCCEEDED;
+  }
+};
+
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+  SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+  SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+  SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+  SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+  SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+  SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+ public WriteCallbackBase {
+public:
+  explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+    : WriteCallbackBase(mcb) {}
+
+  ~WriteCheckTimestampCallback() {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
+    EXPECT_TRUE(gotTimestamp_);
+    EXPECT_TRUE(gotByteSeq_);
+  }
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) override {
+    WriteCallbackBase::setSocket(socket);
+
+    EXPECT_NE(socket_->getFd(), 0);
+    int flags = SOF_TIMESTAMPING_OPT_ID
+                | SOF_TIMESTAMPING_OPT_TSONLY
+                | SOF_TIMESTAMPING_SOFTWARE;
+    AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+    int ret = tstampingOpt.apply(socket_->getFd(), flags);
+    EXPECT_EQ(ret, 0);
+  }
+
+  void checkForTimestampNotifications() noexcept {
+    int fd = socket_->getFd();
+    std::vector<char> ctrl(1024, 0);
+    unsigned char data;
+    struct msghdr msg;
+    iovec entry;
+
+    memset(&msg, 0, sizeof(msg));
+    entry.iov_base = &data;
+    entry.iov_len = sizeof(data);
+    msg.msg_iov = &entry;
+    msg.msg_iovlen = 1;
+    msg.msg_control = ctrl.data();
+    msg.msg_controllen = ctrl.size();
+
+    int ret;
+    while (true) {
+      ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+      if (ret < 0) {
+        if (errno != EAGAIN) {
+          auto errnoCopy = errno;
+          std::cerr << "::recvmsg exited with code " << ret
+                    << ", errno: " << errnoCopy << std::endl;
+          AsyncSocketException ex(
+            AsyncSocketException::INTERNAL_ERROR,
+            "recvmsg() failed",
+            errnoCopy);
+          exception = ex;
+        }
+        return;
+      }
+
+      for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+           cmsg != nullptr && cmsg->cmsg_len != 0;
+           cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level == SOL_SOCKET &&
+            cmsg->cmsg_type == SCM_TIMESTAMPING) {
+          gotTimestamp_ = true;
+          continue;
+        }
+
+        if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+            (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+          gotByteSeq_ = true;
+          continue;
+        }
+      }
+    }
+  }
+
+  bool gotTimestamp_{false};
+  bool gotByteSeq_{false};
 };
+#endif // MSG_ERRQUEUE
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
index d69c851..254c14b 100644 (file)
@@ -229,6 +229,64 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
   bool gotByteSeq_{false};
 };
 
+class TestSendMsgParamsCallback :
+    public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+  : flags_(flags),
+    writeFlags_(folly::WriteFlags::NONE),
+    dataSize_(dataSize),
+    data_(data),
+    queriedFlags_(false),
+    queriedData_(false)
+  {}
+
+  void reset(int flags) {
+    flags_ = flags;
+    writeFlags_ = folly::WriteFlags::NONE;
+    queriedFlags_ = false;
+    queriedData_ = false;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    queriedFlags_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return flags_;
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    queriedData_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    assert(data != nullptr);
+    memcpy(data, data_, dataSize_);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return dataSize_;
+  }
+
+  int flags_;
+  folly::WriteFlags writeFlags_;
+  uint32_t dataSize_;
+  void* data_;
+  bool queriedFlags_;
+  bool queriedData_;
+};
+
 class TestServer {
  public:
   // Create a TestServer.
index 864a403..a198872 100644 (file)
@@ -2823,21 +2823,11 @@ TEST(AsyncSocketTest, EvbCallbacks) {
 /* copied from include/uapi/linux/net_tstamp.h */
 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
 enum SOF_TIMESTAMPING {
-  // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
-  // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
-  // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
-  // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
-  // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
-  // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
-  // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
   SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
-
-  // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
-  // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
 };
 TEST(AsyncSocketTest, ErrMessageCallback) {
   TestServer server;
@@ -3039,3 +3029,167 @@ TEST(AsyncSocket, PreReceivedDataTakeover) {
 
   evb.loop();
 }
+
+TEST(AsyncSocketTest, SendMessageFlags) {
+  TestServer server;
+  TestSendMsgParamsCallback sendMsgCB(
+      MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+
+  evb.loop();
+  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+  // Set SendMsgParamsCallback
+  socket->setSendMsgParamCB(&sendMsgCB);
+  ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
+
+  // Write the first portion of data. This data is expected to be
+  // sent out immediately.
+  std::vector<uint8_t> buf(128, 'a');
+  WriteCallback wcb;
+  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
+  socket->write(&wcb, buf.data(), buf.size());
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+  ASSERT_TRUE(sendMsgCB.queriedFlags_);
+  ASSERT_FALSE(sendMsgCB.queriedData_);
+
+  // Using different flags for the second write operation.
+  // MSG_MORE flag is expected to delay sending this
+  // data to the wire.
+  sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
+  socket->write(&wcb, buf.data(), buf.size());
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+  ASSERT_TRUE(sendMsgCB.queriedFlags_);
+  ASSERT_FALSE(sendMsgCB.queriedData_);
+
+  // Make sure the accepted socket saw only the data from
+  // the first write request.
+  std::vector<uint8_t> readbuf(2 * buf.size());
+  uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
+  ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+  ASSERT_EQ(bytesRead, buf.size());
+
+  // Make sure the server got a connection and received the data
+  acceptedSocket->close();
+  socket->close();
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+TEST(AsyncSocketTest, SendMessageAncillaryData) {
+  struct sockaddr_un addr = {AF_UNIX,
+                             "AsyncSocketTest.SendMessageAncillaryData\0"};
+
+  // Clean up the name in the name space we're going to use
+  ASSERT_FALSE(remove(addr.sun_path) == -1 && errno != ENOENT);
+
+  // Set up listening socket
+  int lfd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
+  ASSERT_NE(lfd, -1);
+  ASSERT_NE(bind(lfd, (struct sockaddr*)&addr, sizeof(addr)), -1)
+      << "Bind failed: " << errno;
+
+  // Create the connecting socket
+  int csd = fsp::socket(AF_UNIX, SOCK_STREAM, 0);
+  ASSERT_NE(csd, -1);
+
+  // Listen for incoming connect
+  ASSERT_NE(listen(lfd, 5), -1);
+
+  // Connect to the listening socket
+  ASSERT_NE(fsp::connect(csd, (struct sockaddr*)&addr, sizeof(addr)), -1)
+      << "Connect request failed: " << errno;
+
+  // Accept the connection
+  int sfd = accept(lfd, nullptr, nullptr);
+  ASSERT_NE(sfd, -1);
+
+  // Instantiate AsyncSocket object for the connected socket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, csd);
+
+  // Open a temporary file and write a magic string to it
+  // We'll transfer the file handle to test the message parameters
+  // callback logic.
+  int tmpfd = open("/var/tmp", O_RDWR | O_TMPFILE);
+  ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
+  std::string magicString("Magic string");
+  ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
+            magicString.length());
+
+  // Send message
+  union {
+    // Space large enough to hold an 'int'
+    char control[CMSG_SPACE(sizeof(int))];
+    struct cmsghdr cmh;
+  } s_u;
+  s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
+  s_u.cmh.cmsg_level = SOL_SOCKET;
+  s_u.cmh.cmsg_type = SCM_RIGHTS;
+  memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
+
+  // Set up the callback providing message parameters
+  TestSendMsgParamsCallback sendMsgCB(
+      MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
+  socket->setSendMsgParamCB(&sendMsgCB);
+
+  // We must transmit at least 1 byte of real data in order
+  // to send ancillary data
+  int s_data = 12345;
+  WriteCallback wcb;
+  socket->write(&wcb, &s_data, sizeof(s_data));
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Receive the message
+  union {
+    // Space large enough to hold an 'int'
+    char control[CMSG_SPACE(sizeof(int))];
+    struct cmsghdr cmh;
+  } r_u;
+  struct msghdr msgh;
+  struct iovec iov;
+  int r_data = 0;
+
+  msgh.msg_control = r_u.control;
+  msgh.msg_controllen = sizeof(r_u.control);
+  msgh.msg_name = nullptr;
+  msgh.msg_namelen = 0;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  iov.iov_base = &r_data;
+  iov.iov_len = sizeof(r_data);
+
+  // Receive data
+  ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
+
+  // Validate the received message
+  ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
+  ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
+  ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
+  ASSERT_EQ(r_data, s_data);
+  int fd = 0;
+  memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
+  ASSERT_NE(fd, 0);
+
+  std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
+
+  // Reposition to the beginning of the file
+  ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
+
+  // Read the magic string back, and compare it with the original
+  ASSERT_EQ(
+      magicString.length(),
+      read(fd, transferredMagicString.data(), transferredMagicString.size()));
+  ASSERT_TRUE(std::equal(
+      magicString.begin(),
+      magicString.end(),
+      transferredMagicString.begin()));
+}