folly: avoid compile warning/failure due to lvalue-to-rvalue conversion
[folly.git] / folly / io / async / AsyncServerSocket.h
index a96f49fe1f7662c20c19a8c406480db9a3fa6971..0a07d546ec274433282a9f4d62f6214c4aac4048 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #pragma once
 
+#include <folly/SocketAddress.h>
+#include <folly/io/ShutdownSocketSet.h>
+#include <folly/io/async/AsyncSocketBase.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/DelayedDestruction.h>
-#include <folly/io/async/EventHandler.h>
 #include <folly/io/async/EventBase.h>
+#include <folly/io/async/EventHandler.h>
 #include <folly/io/async/NotificationQueue.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/io/ShutdownSocketSet.h>
-#include <folly/SocketAddress.h>
-#include <memory>
-#include <exception>
-#include <vector>
+#include <folly/portability/Sockets.h>
+
 #include <limits.h>
 #include <stddef.h>
-#include <sys/socket.h>
-
+#include <exception>
+#include <memory>
+#include <vector>
 
 // Due to the way kernel headers are included, this may or may not be defined.
 // Number pulled from 3.10 kernel headers.
 #define SO_REUSEPORT 15
 #endif
 
+#if defined __linux__ && !defined SO_NO_TRANSPARENT_TLS
+#define SO_NO_TRANSPARENT_TLS 200
+#endif
+
 namespace folly {
 
 /**
@@ -56,13 +61,81 @@ namespace folly {
  * modify the AsyncServerSocket state may only be performed from the primary
  * EventBase thread.
  */
-class AsyncServerSocket : public DelayedDestruction {
+class AsyncServerSocket : public DelayedDestruction
+                        , public AsyncSocketBase {
  public:
   typedef std::unique_ptr<AsyncServerSocket, Destructor> UniquePtr;
+  // Disallow copy, move, and default construction.
+  AsyncServerSocket(AsyncServerSocket&&) = delete;
+
+  /**
+   * A callback interface to get notified of client socket events.
+   *
+   * The ConnectionEventCallback implementations need to be thread-safe as the
+   * callbacks may be called from different threads.
+   */
+  class ConnectionEventCallback {
+   public:
+    virtual ~ConnectionEventCallback() = default;
+
+    /**
+     * onConnectionAccepted() is called right after a client connection
+     * is accepted using the system accept()/accept4() APIs.
+     */
+    virtual void onConnectionAccepted(const int socket,
+                                      const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionAcceptError() is called when an error occurred accepting
+     * a connection.
+     */
+    virtual void onConnectionAcceptError(const int err) noexcept = 0;
+
+    /**
+     * onConnectionDropped() is called when a connection is dropped,
+     * probably because of some error encountered.
+     */
+    virtual void onConnectionDropped(const int socket,
+                                     const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionEnqueuedForAcceptorCallback() is called when the
+     * connection is successfully enqueued for an AcceptCallback to pick up.
+     */
+    virtual void onConnectionEnqueuedForAcceptorCallback(
+        const int socket,
+        const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onConnectionDequeuedByAcceptorCallback() is called when the
+     * connection is successfully dequeued by an AcceptCallback.
+     */
+    virtual void onConnectionDequeuedByAcceptorCallback(
+        const int socket,
+        const SocketAddress& addr) noexcept = 0;
+
+    /**
+     * onBackoffStarted is called when the socket has successfully started
+     * backing off accepting new client sockets.
+     */
+    virtual void onBackoffStarted() noexcept = 0;
+
+    /**
+     * onBackoffEnded is called when the backoff period has ended and the socket
+     * has successfully resumed accepting new connections if there is any
+     * AcceptCallback registered.
+     */
+    virtual void onBackoffEnded() noexcept = 0;
+
+    /**
+     * onBackoffError is called when there is an error entering backoff
+     */
+    virtual void onBackoffError() noexcept = 0;
+  };
 
   class AcceptCallback {
    public:
-    virtual ~AcceptCallback() {}
+    virtual ~AcceptCallback() = default;
 
     /**
      * connectionAccepted() is called whenever a new client connection is
@@ -76,7 +149,7 @@ class AsyncServerSocket : public DelayedDestruction {
      *                    for closing it when done.  The newly accepted file
      *                    descriptor will have already been put into
      *                    non-blocking mode.
-     * @param clientAddr  A reference to a TSocketAddress struct containing the
+     * @param clientAddr  A reference to a SocketAddress struct containing the
      *                    client's address.  This struct is only guaranteed to
      *                    remain valid until connectionAccepted() returns.
      */
@@ -131,7 +204,7 @@ class AsyncServerSocket : public DelayedDestruction {
 
   static const uint32_t kDefaultMaxAcceptAtOnce = 30;
   static const uint32_t kDefaultCallbackAcceptAtOnce = 5;
-  static const uint32_t kDefaultMaxMessagesInQueue = 0;
+  static const uint32_t kDefaultMaxMessagesInQueue = 1024;
   /**
    * Create a new AsyncServerSocket with the specified EventBase.
    *
@@ -154,7 +227,7 @@ class AsyncServerSocket : public DelayedDestruction {
                                                  Destructor());
   }
 
-  void setShutdownSocketSet(ShutdownSocketSet* newSS);
+  void setShutdownSocketSet(const std::weak_ptr<ShutdownSocketSet>& wNewSS);
 
   /**
    * Destroy the socket.
@@ -173,7 +246,7 @@ class AsyncServerSocket : public DelayedDestruction {
    * time after destroy() returns.  They will not receive any more callback
    * invocations once acceptStopped() is invoked.
    */
-  virtual void destroy();
+  void destroy() override;
 
   /**
    * Attach this AsyncServerSocket to its primary EventBase.
@@ -195,7 +268,7 @@ class AsyncServerSocket : public DelayedDestruction {
   /**
    * Get the EventBase used by this socket.
    */
-  EventBase* getEventBase() const {
+  EventBase* getEventBase() const override {
     return eventBase_;
   }
 
@@ -246,6 +319,11 @@ class AsyncServerSocket : public DelayedDestruction {
     }
   }
 
+  /* enable zerocopy support for the server sockets - the s = accept sockets
+   * inherit it
+   */
+  bool setZeroCopy(bool enable);
+
   /**
    * Bind to the specified address.
    *
@@ -280,7 +358,18 @@ class AsyncServerSocket : public DelayedDestruction {
    *
    * Throws TTransportException on error.
    */
-  void getAddress(SocketAddress* addressReturn) const;
+  void getAddress(SocketAddress* addressReturn) const override;
+
+  /**
+   * Get the local address to which the socket is bound.
+   *
+   * Throws TTransportException on error.
+   */
+  SocketAddress getAddress() const {
+    SocketAddress ret;
+    getAddress(&ret);
+    return ret;
+  }
 
   /**
    * Get all the local addresses to which the socket is bound.
@@ -316,8 +405,8 @@ class AsyncServerSocket : public DelayedDestruction {
    *
    * When a new socket is accepted, one of the AcceptCallbacks will be invoked
    * with the new socket.  The AcceptCallbacks are invoked in a round-robin
-   * fashion.  This allows the accepted sockets to distributed among a pool of
-   * threads, each running its own EventBase object.  This is a common model,
+   * fashion.  This allows the accepted sockets to be distributed among a pool
+   * of threads, each running its own EventBase object.  This is a common model,
    * since most asynchronous-style servers typically run one EventBase thread
    * per CPU.
    *
@@ -495,6 +584,25 @@ class AsyncServerSocket : public DelayedDestruction {
     return numDroppedConnections_;
   }
 
+  /**
+   * Get the current number of unprocessed messages in NotificationQueue.
+   *
+   * This method must be invoked from the AsyncServerSocket's primary
+   * EventBase thread.  Use EventBase::runInEventBaseThread() to schedule the
+   * operation in the correct EventBase if your code is not in the server
+   * socket's primary EventBase.
+   */
+  int64_t getNumPendingMessagesInQueue() const {
+    if (eventBase_) {
+      eventBase_->dcheckIsInEventBaseThread();
+    }
+    int64_t numMsgs = 0;
+    for (const auto& callback : callbacks_) {
+      numMsgs += callback.consumer->getQueue()->size();
+    }
+    return numMsgs;
+  }
+
   /**
    * Set whether or not SO_KEEPALIVE should be enabled on the server socket
    * (and thus on all subsequently-accepted connections). By default, keepalive
@@ -573,13 +681,50 @@ class AsyncServerSocket : public DelayedDestruction {
     return closeOnExec_;
   }
 
+  /**
+   * Tries to enable TFO if the machine supports it.
+   */
+  void setTFOEnabled(bool enabled, uint32_t maxTFOQueueSize) {
+    tfo_ = enabled;
+    tfoMaxQueueSize_ = maxTFOQueueSize;
+  }
+
+  /**
+   * Do not attempt the transparent TLS handshake
+   */
+  void disableTransparentTls() {
+    noTransparentTls_ = true;
+  }
+
+  /**
+   * Get whether or not the socket is accepting new connections
+   */
+  bool getAccepting() const {
+    return accepting_;
+  }
+
+  /**
+   * Set the ConnectionEventCallback
+   */
+  void setConnectionEventCallback(
+      ConnectionEventCallback* const connectionEventCallback) {
+    connectionEventCallback_ = connectionEventCallback;
+  }
+
+  /**
+   * Get the ConnectionEventCallback
+   */
+  ConnectionEventCallback* getConnectionEventCallback() const {
+    return connectionEventCallback_;
+  }
+
  protected:
   /**
    * Protected destructor.
    *
    * Invoke destroy() instead to destroy the AsyncServerSocket.
    */
-  virtual ~AsyncServerSocket();
+  ~AsyncServerSocket() override;
 
  private:
   enum class MessageType {
@@ -606,23 +751,26 @@ class AsyncServerSocket : public DelayedDestruction {
    */
   class RemoteAcceptor
       : private NotificationQueue<QueueMessage>::Consumer {
-  public:
-    explicit RemoteAcceptor(AcceptCallback *callback)
-      : callback_(callback) {}
+   public:
+    explicit RemoteAcceptor(AcceptCallback *callback,
+                            ConnectionEventCallback *connectionEventCallback)
+      : callback_(callback),
+        connectionEventCallback_(connectionEventCallback) {}
 
-    ~RemoteAcceptor() {}
+    ~RemoteAcceptor() override = default;
 
     void start(EventBase *eventBase, uint32_t maxAtOnce, uint32_t maxInQueue);
     void stop(EventBase* eventBase, AcceptCallback* callback);
 
-    virtual void messageAvailable(QueueMessage&& message);
+    void messageAvailable(QueueMessage&& message) noexcept override;
 
     NotificationQueue<QueueMessage>* getQueue() {
       return &queue_;
     }
 
-  private:
+   private:
     AcceptCallback *callback_;
+    ConnectionEventCallback* connectionEventCallback_;
 
     NotificationQueue<QueueMessage> queue_;
   };
@@ -649,7 +797,7 @@ class AsyncServerSocket : public DelayedDestruction {
     uint16_t events, int socket, sa_family_t family) noexcept;
 
   int createSocket(int family);
-  void setupSocket(int fd);
+  void setupSocket(int fd, int family);
   void bindSocket(int fd, const SocketAddress& address, bool isExistingSocket);
   void dispatchSocket(int socket, SocketAddress&& address);
   void dispatchError(const char *msg, int errnoValue);
@@ -700,7 +848,7 @@ class AsyncServerSocket : public DelayedDestruction {
     }
 
     // Inherited from EventHandler
-    virtual void handlerReady(uint16_t events) noexcept {
+    void handlerReady(uint16_t events) noexcept override {
       parent_->handlerReady(events, socket_, addressFamily_);
     }
 
@@ -726,7 +874,11 @@ class AsyncServerSocket : public DelayedDestruction {
   bool keepAliveEnabled_;
   bool reusePortEnabled_{false};
   bool closeOnExec_;
-  ShutdownSocketSet* shutdownSocketSet_;
+  bool tfo_{false};
+  bool noTransparentTls_{false};
+  uint32_t tfoMaxQueueSize_{0};
+  std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
+  ConnectionEventCallback* connectionEventCallback_{nullptr};
 };
 
-} // folly
+} // namespace folly