Add connection event callback to AsyncServerSocket
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
index 74da66215c375acfb149ae9d7447e14bac28d5b7..2c822eef11c5b8e27eec46b73ab2de1e2ae6e84c 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/EventBase.h>
+#include <folly/RWSpinLock.h>
 #include <folly/SocketAddress.h>
 
 #include <folly/io/IOBuf.h>
@@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) {
 ///////////////////////////////////////////////////////////////////////////
 // AsyncServerSocket tests
 ///////////////////////////////////////////////////////////////////////////
+namespace {
+/**
+ * Helper ConnectionEventCallback class for the test code.
+ * It maintains counters protected by a spin lock.
+ */
+class TestConnectionEventCallback :
+  public AsyncServerSocket::ConnectionEventCallback {
+ public:
+  virtual void onConnectionAccepted(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionAccepted_++;
+  }
+
+  virtual void onConnectionAcceptError(const int err) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionAcceptedError_++;
+  }
+
+  virtual void onConnectionDropped(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionDropped_++;
+  }
+
+  virtual void onConnectionEnqueuedForAcceptCallback(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionEnqueuedForAcceptCallback_++;
+  }
+
+  virtual void onConnectionDequeuedByAcceptCallback(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionDequeuedByAcceptCallback_++;
+  }
+
+  virtual void onBackoffStarted() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffStarted_++;
+  }
+
+  virtual void onBackoffEnded() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffEnded_++;
+  }
+
+  virtual void onBackoffError() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffError_++;
+  }
+
+  unsigned int getConnectionAccepted() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionAccepted_;
+  }
+
+  unsigned int getConnectionAcceptedError() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionAcceptedError_;
+  }
+
+  unsigned int getConnectionDropped() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionDropped_;
+  }
+
+  unsigned int getConnectionEnqueuedForAcceptCallback() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionEnqueuedForAcceptCallback_;
+  }
+
+  unsigned int getConnectionDequeuedByAcceptCallback() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionDequeuedByAcceptCallback_;
+  }
+
+  unsigned int getBackoffStarted() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffStarted_;
+  }
+
+  unsigned int getBackoffEnded() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffEnded_;
+  }
+
+  unsigned int getBackoffError() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffError_;
+  }
+
+ private:
+  mutable folly::RWSpinLock spinLock_;
+  unsigned int connectionAccepted_{0};
+  unsigned int connectionAcceptedError_{0};
+  unsigned int connectionDropped_{0};
+  unsigned int connectionEnqueuedForAcceptCallback_{0};
+  unsigned int connectionDequeuedByAcceptCallback_{0};
+  unsigned int backoffStarted_{0};
+  unsigned int backoffEnded_{0};
+  unsigned int backoffError_{0};
+};
 
 /**
  * Helper AcceptCallback class for the test code
@@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
 
   std::deque<EventInfo> events_;
 };
+}
 
 /**
  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
@@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
   int flags = fcntl(fd, F_GETFL, 0);
   CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
 }
+
+TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
+  EventBase eventBase;
+  TestConnectionEventCallback connectionEventCallback;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->setConnectionEventCallback(&connectionEventCallback);
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  // Validate the connection event counters
+  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
+  ASSERT_EQ(
+      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
+  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
+}