Add connection event callback to AsyncServerSocket
authorMohammad Husain <mhusain@fb.com>
Fri, 16 Oct 2015 04:13:25 +0000 (21:13 -0700)
committerfacebook-github-bot-4 <folly-bot@fb.com>
Fri, 16 Oct 2015 05:20:28 +0000 (22:20 -0700)
Summary: Adding a callback to AsyncServerSocket to get notified of client connection events. This can be used for example to record stats about these events.

Reviewed By: @afrind

Differential Revision: D2544776

fb-gh-sync-id: 20d22cfc939c5b937abec2b600c10b7228923ff3

folly/io/async/AsyncServerSocket.cpp
folly/io/async/AsyncServerSocket.h
folly/io/async/test/AsyncSocketTest2.cpp

index 75e0da30ad45a8cfc01adf373c8dae438294bc27..b2b2d24be7510999634bf48aabdd529bfb050439 100644 (file)
@@ -91,6 +91,10 @@ void AsyncServerSocket::RemoteAcceptor::messageAvailable(
   switch (msg.type) {
     case MessageType::MSG_NEW_CONN:
     {
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onConnectionDequeuedByAcceptCallback(
+            msg.fd, msg.address);
+      }
       callback_->connectionAccepted(msg.fd, msg.address);
       break;
     }
@@ -515,7 +519,7 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
   // callback more efficiently without having to use a notification queue.
   RemoteAcceptor* acceptor = nullptr;
   try {
-    acceptor = new RemoteAcceptor(callback);
+    acceptor = new RemoteAcceptor(callback, connectionEventCallback_);
     acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_);
   } catch (...) {
     callbacks_.pop_back();
@@ -722,6 +726,10 @@ void AsyncServerSocket::handlerReady(
 
     address.setFromSockaddr(saddr, addrLen);
 
+    if (clientSocket >= 0 && connectionEventCallback_) {
+      connectionEventCallback_->onConnectionAccepted(clientSocket, address);
+    }
+
     std::chrono::time_point<std::chrono::steady_clock> nowMs =
       std::chrono::steady_clock::now();
     auto timeSinceLastAccept = std::max<int64_t>(
@@ -737,6 +745,10 @@ void AsyncServerSocket::handlerReady(
         ++numDroppedConnections_;
         if (clientSocket >= 0) {
           closeNoInt(clientSocket);
+          if (connectionEventCallback_) {
+            connectionEventCallback_->onConnectionDropped(clientSocket,
+                                                          address);
+          }
         }
         continue;
       }
@@ -760,6 +772,9 @@ void AsyncServerSocket::handlerReady(
       } else {
         dispatchError("accept() failed", errno);
       }
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onConnectionAcceptError(errno);
+      }
       return;
     }
 
@@ -769,6 +784,9 @@ void AsyncServerSocket::handlerReady(
       closeNoInt(clientSocket);
       dispatchError("failed to set accepted socket to non-blocking mode",
                     errno);
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onConnectionDropped(clientSocket, address);
+      }
       return;
     }
 #endif
@@ -795,6 +813,7 @@ void AsyncServerSocket::dispatchSocket(int socket,
     return;
   }
 
+  const SocketAddress addr(address);
   // Create a message to send over the notification queue
   QueueMessage msg;
   msg.type = MessageType::MSG_NEW_CONN;
@@ -804,9 +823,13 @@ void AsyncServerSocket::dispatchSocket(int socket,
   // Loop until we find a free queue to write to
   while (true) {
     if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) {
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onConnectionEnqueuedForAcceptCallback(socket,
+                                                                        addr);
+      }
       // Success! return.
       return;
-   }
+    }
 
     // We couldn't add to queue.  Fall through to below
 
@@ -831,6 +854,9 @@ void AsyncServerSocket::dispatchSocket(int socket,
       LOG(ERROR) << "failed to dispatch newly accepted socket:"
                  << " all accept callback queues are full";
       closeNoInt(socket);
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onConnectionDropped(socket, addr);
+      }
       return;
     }
 
@@ -886,6 +912,9 @@ void AsyncServerSocket::enterBackoff() {
       // since we won't be able to re-enable ourselves later.
       LOG(ERROR) << "failed to allocate AsyncServerSocket backoff"
                  << " timer; unable to temporarly pause accepting";
+      if (connectionEventCallback_) {
+        connectionEventCallback_->onBackoffError();
+      }
       return;
     }
   }
@@ -903,6 +932,9 @@ void AsyncServerSocket::enterBackoff() {
   if (!backoffTimeout_->scheduleTimeout(timeoutMS)) {
     LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;"
                << "unable to temporarly pause accepting";
+    if (connectionEventCallback_) {
+      connectionEventCallback_->onBackoffError();
+    }
     return;
   }
 
@@ -912,6 +944,9 @@ void AsyncServerSocket::enterBackoff() {
   for (auto& handler : sockets_) {
     handler.unregisterHandler();
   }
+  if (connectionEventCallback_) {
+    connectionEventCallback_->onBackoffStarted();
+  }
 }
 
 void AsyncServerSocket::backoffTimeoutExpired() {
@@ -924,6 +959,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
 
   // If all of the callbacks were removed, we shouldn't re-enable accepts
   if (callbacks_.empty()) {
+    if (connectionEventCallback_) {
+      connectionEventCallback_->onBackoffEnded();
+    }
     return;
   }
 
@@ -942,6 +980,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
       abort();
     }
   }
+  if (connectionEventCallback_) {
+    connectionEventCallback_->onBackoffEnded();
+  }
 }
 
 
index 935e1917049722fea6079d9af4655e2c60be52da..4f1194f7c98b3f4a766e2190c79cf4813149c29d 100644 (file)
@@ -64,6 +64,71 @@ class AsyncServerSocket : public DelayedDestruction
   // Disallow copy, move, and default construction.
   AsyncServerSocket(AsyncServerSocket&&) = delete;
 
+  /**
+   * A callback interface to get notified of client socket events.
+   *
+   * The ConnectionEventCallback implementations need to be thread-safe as the
+   * callbacks may be called from different threads.
+   */
+  class ConnectionEventCallback {
+   public:
+    virtual ~ConnectionEventCallback() = default;
+
+    /**
+     * onConnectionAccepted() is called right after a client connection
+     * is accepted using the system accept()/accept4() APIs.
+     */
+    virtual void onConnectionAccepted(const int socket,
+                                      const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionAcceptError() is called when an error occurred accepting
+     * a connection.
+     */
+    virtual void onConnectionAcceptError(const int err) noexcept = 0;
+
+    /**
+     * onConnectionDropped() is called when a connection is dropped,
+     * probably because of some error encountered.
+     */
+    virtual void onConnectionDropped(const int socket,
+                                     const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionEnqueuedForAcceptCallback() is called when the
+     * connection is successfully enqueued for an AcceptCallback to pick up.
+     */
+    virtual void onConnectionEnqueuedForAcceptCallback(
+        const int socket,
+        const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionDequeuedByAcceptCallback() is called when the
+     * connection is successfully dequeued by an AcceptCallback.
+     */
+    virtual void onConnectionDequeuedByAcceptCallback(
+        const int socket,
+        const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onBackoffStarted is called when the socket has successfully started
+     * backing off accepting new client sockets.
+     */
+    virtual void onBackoffStarted() noexcept = 0;
+
+    /**
+     * onBackoffEnded is called when the backoff period has ended and the socket
+     * has successfully resumed accepting new connections if there is any
+     * AcceptCallback registered.
+     */
+    virtual void onBackoffEnded() noexcept = 0;
+
+    /**
+     * onBackoffError is called when there is an error entering backoff
+     */
+    virtual void onBackoffError() noexcept = 0;
+  };
+
   class AcceptCallback {
    public:
     virtual ~AcceptCallback() = default;
@@ -320,8 +385,8 @@ class AsyncServerSocket : public DelayedDestruction
    *
    * When a new socket is accepted, one of the AcceptCallbacks will be invoked
    * with the new socket.  The AcceptCallbacks are invoked in a round-robin
-   * fashion.  This allows the accepted sockets to distributed among a pool of
-   * threads, each running its own EventBase object.  This is a common model,
+   * fashion.  This allows the accepted sockets to be distributed among a pool
+   * of threads, each running its own EventBase object.  This is a common model,
    * since most asynchronous-style servers typically run one EventBase thread
    * per CPU.
    *
@@ -584,6 +649,21 @@ class AsyncServerSocket : public DelayedDestruction
     return accepting_;
   }
 
+  /**
+   * Set the ConnectionEventCallback
+   */
+  void setConnectionEventCallback(
+      ConnectionEventCallback* const connectionEventCallback) {
+    connectionEventCallback_ = connectionEventCallback;
+  }
+
+  /**
+   * Get the ConnectionEventCallback
+   */
+  ConnectionEventCallback* getConnectionEventCallback() const {
+    return connectionEventCallback_;
+  }
+
  protected:
   /**
    * Protected destructor.
@@ -618,8 +698,10 @@ class AsyncServerSocket : public DelayedDestruction
   class RemoteAcceptor
       : private NotificationQueue<QueueMessage>::Consumer {
   public:
-    explicit RemoteAcceptor(AcceptCallback *callback)
-      : callback_(callback) {}
+    explicit RemoteAcceptor(AcceptCallback *callback,
+                            ConnectionEventCallback *connectionEventCallback)
+      : callback_(callback),
+        connectionEventCallback_(connectionEventCallback) {}
 
     ~RemoteAcceptor() = default;
 
@@ -634,6 +716,7 @@ class AsyncServerSocket : public DelayedDestruction
 
   private:
     AcceptCallback *callback_;
+    ConnectionEventCallback* connectionEventCallback_;
 
     NotificationQueue<QueueMessage> queue_;
   };
@@ -738,6 +821,7 @@ class AsyncServerSocket : public DelayedDestruction
   bool reusePortEnabled_{false};
   bool closeOnExec_;
   ShutdownSocketSet* shutdownSocketSet_;
+  ConnectionEventCallback* connectionEventCallback_{nullptr};
 };
 
 } // folly
index 74da66215c375acfb149ae9d7447e14bac28d5b7..2c822eef11c5b8e27eec46b73ab2de1e2ae6e84c 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/EventBase.h>
+#include <folly/RWSpinLock.h>
 #include <folly/SocketAddress.h>
 
 #include <folly/io/IOBuf.h>
@@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) {
 ///////////////////////////////////////////////////////////////////////////
 // AsyncServerSocket tests
 ///////////////////////////////////////////////////////////////////////////
+namespace {
+/**
+ * Helper ConnectionEventCallback class for the test code.
+ * It maintains counters protected by a spin lock.
+ */
+class TestConnectionEventCallback :
+  public AsyncServerSocket::ConnectionEventCallback {
+ public:
+  virtual void onConnectionAccepted(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionAccepted_++;
+  }
+
+  virtual void onConnectionAcceptError(const int err) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionAcceptedError_++;
+  }
+
+  virtual void onConnectionDropped(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionDropped_++;
+  }
+
+  virtual void onConnectionEnqueuedForAcceptCallback(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionEnqueuedForAcceptCallback_++;
+  }
+
+  virtual void onConnectionDequeuedByAcceptCallback(
+      const int socket,
+      const SocketAddress& addr) noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    connectionDequeuedByAcceptCallback_++;
+  }
+
+  virtual void onBackoffStarted() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffStarted_++;
+  }
+
+  virtual void onBackoffEnded() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffEnded_++;
+  }
+
+  virtual void onBackoffError() noexcept override {
+    folly::RWSpinLock::WriteHolder holder(spinLock_);
+    backoffError_++;
+  }
+
+  unsigned int getConnectionAccepted() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionAccepted_;
+  }
+
+  unsigned int getConnectionAcceptedError() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionAcceptedError_;
+  }
+
+  unsigned int getConnectionDropped() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionDropped_;
+  }
+
+  unsigned int getConnectionEnqueuedForAcceptCallback() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionEnqueuedForAcceptCallback_;
+  }
+
+  unsigned int getConnectionDequeuedByAcceptCallback() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return connectionDequeuedByAcceptCallback_;
+  }
+
+  unsigned int getBackoffStarted() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffStarted_;
+  }
+
+  unsigned int getBackoffEnded() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffEnded_;
+  }
+
+  unsigned int getBackoffError() const {
+    folly::RWSpinLock::ReadHolder holder(spinLock_);
+    return backoffError_;
+  }
+
+ private:
+  mutable folly::RWSpinLock spinLock_;
+  unsigned int connectionAccepted_{0};
+  unsigned int connectionAcceptedError_{0};
+  unsigned int connectionDropped_{0};
+  unsigned int connectionEnqueuedForAcceptCallback_{0};
+  unsigned int connectionDequeuedByAcceptCallback_{0};
+  unsigned int backoffStarted_{0};
+  unsigned int backoffEnded_{0};
+  unsigned int backoffError_{0};
+};
 
 /**
  * Helper AcceptCallback class for the test code
@@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
 
   std::deque<EventInfo> events_;
 };
+}
 
 /**
  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
@@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
   int flags = fcntl(fd, F_GETFL, 0);
   CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
 }
+
+TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
+  EventBase eventBase;
+  TestConnectionEventCallback connectionEventCallback;
+
+  // Create a server socket
+  std::shared_ptr<AsyncServerSocket> serverSocket(
+      AsyncServerSocket::newSocket(&eventBase));
+  serverSocket->setConnectionEventCallback(&connectionEventCallback);
+  serverSocket->bind(0);
+  serverSocket->listen(16);
+  folly::SocketAddress serverAddress;
+  serverSocket->getAddress(&serverAddress);
+
+  // Add a callback to accept one connection then stop the loop
+  TestAcceptCallback acceptCallback;
+  acceptCallback.setConnectionAcceptedFn(
+    [&](int fd, const folly::SocketAddress& addr) {
+      serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+    });
+  acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
+    serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
+  });
+  serverSocket->addAcceptCallback(&acceptCallback, nullptr);
+  serverSocket->startAccepting();
+
+  // Connect to the server socket
+  std::shared_ptr<AsyncSocket> socket(
+      AsyncSocket::newSocket(&eventBase, serverAddress));
+
+  eventBase.loop();
+
+  // Validate the connection event counters
+  ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
+  ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
+  ASSERT_EQ(
+      connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
+  ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
+  ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
+  ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
+}