Udp Acceptor
authorDave Watson <davejwatson@fb.com>
Fri, 20 Mar 2015 16:54:08 +0000 (09:54 -0700)
committerNoam Lerner <noamler@fb.com>
Wed, 25 Mar 2015 22:34:48 +0000 (15:34 -0700)
Summary:
major changes:

1) ServerSocketFactory and AsyncSocketBase to abstract the differences between UDP and TCP async socket.  Could possibly push some of this to the sockets themselves eventually
2) pipeline() is a pipeline between accept/receive of a UDP message, and before sending it to workers.  Default impl for TCP is to fan out to worker threads.  This is the same as Netty.  Since we don't know if the data is a TCP socket or a UDP message, it's a void*, which sucks (netty uses Object msg, so it isn't any different).

Test Plan: Added lots of new tests.  Doesn't test any data passing yet though, just connects/simple receipt of UDP message.

Reviewed By: hans@fb.com

Subscribers: alandau, bmatheny, mshneer, jsedgwick, yfeldblum, trunkagent, doug, fugalh, folly-diffs@

FB internal diff: D1736670

Tasks: 5788116

Signature: t1:1736670:1424372992:e109450604ed905004bd40dfbb508b5808332c15

12 files changed:
folly/Makefile.am
folly/io/async/AsyncServerSocket.h
folly/io/async/AsyncSocketBase.h [new file with mode: 0644]
folly/io/async/AsyncTransport.h
folly/io/async/AsyncUDPServerSocket.h
folly/io/async/AsyncUDPSocket.h
folly/wangle/acceptor/Acceptor.h
folly/wangle/bootstrap/BootstrapTest.cpp
folly/wangle/bootstrap/ServerBootstrap-inl.h
folly/wangle/bootstrap/ServerBootstrap.cpp
folly/wangle/bootstrap/ServerBootstrap.h
folly/wangle/bootstrap/ServerSocketFactory.h [new file with mode: 0644]

index 72272940a41f1d4fa0d1aa1315ea49654164080e..a5aec762e285bb8c3e2cfae65165208e8ec95ccc 100644 (file)
@@ -148,6 +148,7 @@ nobase_follyinclude_HEADERS = \
        io/async/AsyncUDPSocket.h \
        io/async/AsyncServerSocket.h \
        io/async/AsyncSocket.h \
+       io/async/AsyncSocketBase.h \
        io/async/AsyncSSLSocket.h \
        io/async/AsyncSocketException.h \
        io/async/DelayedDestruction.h \
@@ -234,6 +235,7 @@ nobase_follyinclude_HEADERS = \
        wangle/acceptor/TransportInfo.h \
        wangle/bootstrap/ServerBootstrap.h \
        wangle/bootstrap/ServerBootstrap-inl.h \
+       wangle/bootstrap/ServerSocketFactory.h \
        wangle/bootstrap/ClientBootstrap.h \
        wangle/channel/AsyncSocketHandler.h \
        wangle/channel/ChannelHandler.h \
index 46576aa255556dc97efa5f12cb005833b4f4b277..ff9562fe0c4b3302c58832bc0868a1a8d0f5fe3f 100644 (file)
@@ -21,6 +21,7 @@
 #include <folly/io/async/EventBase.h>
 #include <folly/io/async/NotificationQueue.h>
 #include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/AsyncSocketBase.h>
 #include <folly/io/ShutdownSocketSet.h>
 #include <folly/SocketAddress.h>
 #include <memory>
@@ -56,7 +57,8 @@ namespace folly {
  * modify the AsyncServerSocket state may only be performed from the primary
  * EventBase thread.
  */
-class AsyncServerSocket : public DelayedDestruction {
+class AsyncServerSocket : public DelayedDestruction
+                        , public AsyncSocketBase {
  public:
   typedef std::unique_ptr<AsyncServerSocket, Destructor> UniquePtr;
   // Disallow copy, move, and default construction.
diff --git a/folly/io/async/AsyncSocketBase.h b/folly/io/async/AsyncSocketBase.h
new file mode 100644 (file)
index 0000000..35e8a15
--- /dev/null
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/SocketAddress.h>
+#include <folly/io/async/EventBase.h>
+
+namespace folly {
+
+class AsyncSocketBase {
+ public:
+  virtual EventBase* getEventBase() const = 0;
+  virtual ~AsyncSocketBase() = default;
+  virtual void getAddress(SocketAddress*) const = 0;
+};
+
+} // namespace
index 95fdd9ddd08d9fc79ce294f12bc7fc97d1920cfe..1cf4d209d91eae5bc2cb17f1c41c15101ec617ac 100644 (file)
@@ -20,6 +20,8 @@
 #include <sys/uio.h>
 
 #include <folly/io/async/DelayedDestruction.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/AsyncSocketBase.h>
 
 namespace folly {
 
@@ -111,7 +113,7 @@ inline bool isSet(WriteFlags a, WriteFlags b) {
  * timeout, since most callers want to give up if the remote end stops
  * responding and no further progress can be made sending the data.
  */
-class AsyncTransport : public DelayedDestruction {
+class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
  public:
   typedef std::unique_ptr<AsyncTransport, Destructor> UniquePtr;
 
@@ -256,14 +258,6 @@ class AsyncTransport : public DelayedDestruction {
    */
   virtual bool isDetachable() const = 0;
 
-  /**
-   * Get the EventBase used by this transport.
-   *
-   * Returns nullptr if this transport is not currently attached to a
-   * EventBase.
-   */
-  virtual EventBase* getEventBase() const = 0;
-
   /**
    * Set the send timeout.
    *
@@ -296,6 +290,10 @@ class AsyncTransport : public DelayedDestruction {
    */
   virtual void getLocalAddress(SocketAddress* address) const = 0;
 
+  virtual void getAddress(SocketAddress* address) const {
+    getLocalAddress(address);
+  }
+
   /**
    * Get the address of the remote endpoint to which this transport is
    * connected.
index 9feb6d41f2ee45d486ba8dec58a0ce2531805fb2..e424e83d9bd4ab35ce0b9752b57cd45e37962129 100644 (file)
@@ -36,7 +36,8 @@ namespace folly {
  *       more than 1 packet will not work because they will end up with
  *       different event base to process.
  */
-class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback {
+class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback
+                           , public AsyncSocketBase {
  public:
   class Callback {
    public:
@@ -93,6 +94,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback {
     return socket_->address();
   }
 
+  void getAddress(SocketAddress* a) const {
+    *a = address();
+  }
+
   /**
    * Add a listener to the round robin list
    */
@@ -124,6 +129,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback {
     socket_.reset();
   }
 
+  EventBase* getEventBase() const {
+    return evb_;
+  }
+
  private:
   // AsyncUDPSocket::ReadCallback
   void getReadBuffer(void** buf, size_t* len) noexcept {
index 0341b00f963b7705d0e27e5a548f03d81fcf0fe8..a1bca318f07ef5b87b21c0b5ae46b03564491496 100644 (file)
@@ -19,6 +19,7 @@
 #include <folly/io/IOBuf.h>
 #include <folly/ScopeGuard.h>
 #include <folly/io/async/AsyncSocketException.h>
+#include <folly/io/async/AsyncSocketBase.h>
 #include <folly/io/async/EventHandler.h>
 #include <folly/io/async/EventBase.h>
 #include <folly/SocketAddress.h>
index 7d7269c502156bbb343c92df98a05450820039c1..c82f97a6755f09f0933ee7180a54a055918724dc 100644 (file)
@@ -20,6 +20,7 @@
 #include <event.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncUDPServerSocket.h>
 
 namespace folly { namespace wangle {
 class ManagedConnection;
@@ -46,7 +47,8 @@ class SSLContextManager;
  */
 class Acceptor :
   public folly::AsyncServerSocket::AcceptCallback,
-  public folly::wangle::ConnectionManager::Callback {
+  public folly::wangle::ConnectionManager::Callback,
+  public AsyncUDPServerSocket::Callback  {
  public:
 
   enum class State : uint32_t {
@@ -229,6 +231,10 @@ class Acceptor :
       const std::string& nextProtocolName,
       const TransportInfo& tinfo) = 0;
 
+  void onListenStarted() noexcept {}
+  void onListenStopped() noexcept {}
+  void onDataAvailable(const SocketAddress&, std::unique_ptr<IOBuf>, bool) noexcept {}
+
   virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
     return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
   }
index 4ba1880d0f6eb3f88b7ac11738a5b980f0f321a1..d0ee7e2a4a0ca336aa98690bb34db09a884435db 100644 (file)
@@ -52,6 +52,27 @@ class TestPipelineFactory : public PipelineFactory<Pipeline> {
   std::atomic<int> pipelines{0};
 };
 
+class TestAcceptor : public Acceptor {
+EventBase base_;
+ public:
+  TestAcceptor() : Acceptor(ServerSocketConfig()) {
+    Acceptor::init(nullptr, &base_);
+  }
+  void onNewConnection(
+      AsyncSocket::UniquePtr sock,
+      const folly::SocketAddress* address,
+      const std::string& nextProtocolName,
+        const TransportInfo& tinfo) {
+  }
+};
+
+class TestAcceptorFactory : public AcceptorFactory {
+ public:
+  std::shared_ptr<Acceptor> newAcceptor(EventBase* base) {
+    return std::make_shared<TestAcceptor>();
+  }
+};
+
 TEST(Bootstrap, Basic) {
   TestServer server;
   TestClient client;
@@ -64,6 +85,13 @@ TEST(Bootstrap, ServerWithPipeline) {
   server.stop();
 }
 
+TEST(Bootstrap, ServerWithChildHandler) {
+  TestServer server;
+  server.childHandler(std::make_shared<TestAcceptorFactory>());
+  server.bind(0);
+  server.stop();
+}
+
 TEST(Bootstrap, ClientServerTest) {
   TestServer server;
   auto factory = std::make_shared<TestPipelineFactory>();
@@ -236,3 +264,107 @@ TEST(Bootstrap, ExistingSocket) {
   folly::AsyncServerSocket::UniquePtr socket(new AsyncServerSocket);
   server.bind(std::move(socket));
 }
+
+std::atomic<int> connections{0};
+
+class TestHandlerPipeline
+    : public ChannelHandlerAdapter<void*,
+                                   std::exception> {
+ public:
+  void read(Context* ctx, void* conn) {
+    connections++;
+    return ctx->fireRead(conn);
+  }
+
+  Future<void> write(Context* ctx, std::exception e) {
+    return ctx->fireWrite(e);
+  }
+};
+
+template <typename HandlerPipeline>
+class TestHandlerPipelineFactory
+    : public PipelineFactory<ServerBootstrap<Pipeline>::AcceptPipeline> {
+ public:
+  ServerBootstrap<Pipeline>::AcceptPipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
+    auto pipeline = new ServerBootstrap<Pipeline>::AcceptPipeline;
+    auto handler = std::make_shared<HandlerPipeline>();
+      pipeline->addBack(ChannelHandlerPtr<HandlerPipeline>(handler));
+    return pipeline;
+  }
+};
+
+TEST(Bootstrap, LoadBalanceHandler) {
+  TestServer server;
+  auto factory = std::make_shared<TestPipelineFactory>();
+  server.childPipeline(factory);
+
+  auto pipelinefactory =
+    std::make_shared<TestHandlerPipelineFactory<TestHandlerPipeline>>();
+  server.pipeline(pipelinefactory);
+  server.bind(0);
+  auto base = EventBaseManager::get()->getEventBase();
+
+  SocketAddress address;
+  server.getSockets()[0]->getAddress(&address);
+
+  TestClient client;
+  client.pipelineFactory(std::make_shared<TestClientPipelineFactory>());
+  client.connect(address);
+  base->loop();
+  server.stop();
+
+  CHECK(factory->pipelines == 1);
+  CHECK(connections == 1);
+}
+
+class TestUDPPipeline
+    : public ChannelHandlerAdapter<void*,
+                                   std::exception> {
+ public:
+  void read(Context* ctx, void* conn) {
+    connections++;
+  }
+
+  Future<void> write(Context* ctx, std::exception e) {
+    return ctx->fireWrite(e);
+  }
+};
+
+TEST(Bootstrap, UDP) {
+  TestServer server;
+  auto factory = std::make_shared<TestPipelineFactory>();
+  auto pipelinefactory =
+    std::make_shared<TestHandlerPipelineFactory<TestUDPPipeline>>();
+  server.pipeline(pipelinefactory);
+  server.channelFactory(std::make_shared<AsyncUDPServerSocketFactory>());
+  server.bind(0);
+}
+
+TEST(Bootstrap, UDPClientServerTest) {
+  connections = 0;
+
+  TestServer server;
+  auto factory = std::make_shared<TestPipelineFactory>();
+  auto pipelinefactory =
+    std::make_shared<TestHandlerPipelineFactory<TestUDPPipeline>>();
+  server.pipeline(pipelinefactory);
+  server.channelFactory(std::make_shared<AsyncUDPServerSocketFactory>());
+  server.bind(0);
+
+  auto base = EventBaseManager::get()->getEventBase();
+
+  SocketAddress address;
+  server.getSockets()[0]->getAddress(&address);
+
+  SocketAddress localhost("::1", 0);
+  AsyncUDPSocket client(base);
+  client.bind(localhost);
+  auto data = IOBuf::create(1);
+  data->append(1);
+  *(data->writableData()) = 'a';
+  client.write(address, std::move(data));
+  base->loop();
+  server.stop();
+
+  CHECK(connections == 1);
+}
index ac18f314bd946cf50ce4a88715bc8bf279b33983..a6ffd26884b94f79085b9405ee90ba1aefed2a98 100644 (file)
 #pragma once
 
 #include <folly/wangle/acceptor/Acceptor.h>
+#include <folly/wangle/bootstrap/ServerSocketFactory.h>
 #include <folly/io/async/EventBaseManager.h>
 #include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
 #include <folly/wangle/acceptor/ManagedConnection.h>
 #include <folly/wangle/channel/ChannelPipeline.h>
+#include <folly/wangle/channel/ChannelHandler.h>
 
 namespace folly {
 
 template <typename Pipeline>
-class ServerAcceptor : public Acceptor {
+class ServerAcceptor
+    : public Acceptor
+    , public folly::wangle::ChannelHandlerAdapter<void*, std::exception> {
   typedef std::unique_ptr<Pipeline,
                           folly::DelayedDestruction::Destructor> PipelinePtr;
 
@@ -55,21 +59,26 @@ class ServerAcceptor : public Acceptor {
 
  public:
   explicit ServerAcceptor(
-    std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory,
-    EventBase* base)
+        std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory,
+        std::shared_ptr<folly::wangle::ChannelPipeline<
+                          void*, std::exception>> acceptorPipeline,
+        EventBase* base)
       : Acceptor(ServerSocketConfig())
-      , pipelineFactory_(pipelineFactory) {
-    Acceptor::init(nullptr, base);
+      , base_(base)
+      , childPipelineFactory_(pipelineFactory)
+      , acceptorPipeline_(acceptorPipeline) {
+    Acceptor::init(nullptr, base_);
+    CHECK(acceptorPipeline_);
+
+    acceptorPipeline_->addBack(folly::wangle::ChannelHandlerPtr<ServerAcceptor, false>(this));
+    acceptorPipeline_->finalize();
   }
 
-  /* See Acceptor::onNewConnection for details */
-  void onNewConnection(
-    AsyncSocket::UniquePtr transport, const SocketAddress* address,
-    const std::string& nextProtocolName, const TransportInfo& tinfo) {
-
+  void read(Context* ctx, void* conn) {
+    AsyncSocket::UniquePtr transport((AsyncSocket*)conn);
       std::unique_ptr<Pipeline,
                        folly::DelayedDestruction::Destructor>
-      pipeline(pipelineFactory_->newPipeline(
+      pipeline(childPipelineFactory_->newPipeline(
         std::shared_ptr<AsyncSocket>(
           transport.release(),
           folly::DelayedDestruction::Destructor())));
@@ -77,22 +86,53 @@ class ServerAcceptor : public Acceptor {
     Acceptor::addConnection(connection);
   }
 
+  folly::Future<void> write(Context* ctx, std::exception e) {
+    return ctx->fireWrite(e);
+  }
+
+  /* See Acceptor::onNewConnection for details */
+  void onNewConnection(
+    AsyncSocket::UniquePtr transport, const SocketAddress* address,
+    const std::string& nextProtocolName, const TransportInfo& tinfo) {
+    acceptorPipeline_->read(transport.release());
+  }
+
+  // UDP thunk
+  void onDataAvailable(const folly::SocketAddress& addr,
+                       std::unique_ptr<folly::IOBuf> buf,
+                       bool truncated) noexcept {
+    acceptorPipeline_->read(buf.release());
+  }
+
  private:
-  std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
+  EventBase* base_;
+
+  std::shared_ptr<PipelineFactory<Pipeline>> childPipelineFactory_;
+  std::shared_ptr<folly::wangle::ChannelPipeline<
+    void*, std::exception>> acceptorPipeline_;
 };
 
 template <typename Pipeline>
 class ServerAcceptorFactory : public AcceptorFactory {
  public:
   explicit ServerAcceptorFactory(
-      std::shared_ptr<PipelineFactory<Pipeline>> factory)
-    : factory_(factory) {}
-
-  std::shared_ptr<Acceptor> newAcceptor(folly::EventBase* base) {
-    return std::make_shared<ServerAcceptor<Pipeline>>(factory_, base);
+    std::shared_ptr<PipelineFactory<Pipeline>> factory,
+    std::shared_ptr<PipelineFactory<folly::wangle::ChannelPipeline<
+    void*, std::exception>>> pipeline)
+    : factory_(factory)
+    , pipeline_(pipeline) {}
+
+  std::shared_ptr<Acceptor> newAcceptor(EventBase* base) {
+    std::shared_ptr<folly::wangle::ChannelPipeline<
+                      void*, std::exception>> pipeline(
+                        pipeline_->newPipeline(nullptr));
+    return std::make_shared<ServerAcceptor<Pipeline>>(factory_, pipeline, base);
   }
  private:
   std::shared_ptr<PipelineFactory<Pipeline>> factory_;
+  std::shared_ptr<PipelineFactory<
+    folly::wangle::ChannelPipeline<
+      void*, std::exception>>> pipeline_;
 };
 
 class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
@@ -100,10 +140,12 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
   explicit ServerWorkerPool(
     std::shared_ptr<AcceptorFactory> acceptorFactory,
     folly::wangle::IOThreadPoolExecutor* exec,
-    std::vector<std::shared_ptr<folly::AsyncServerSocket>>* sockets)
+    std::vector<std::shared_ptr<folly::AsyncSocketBase>>* sockets,
+    std::shared_ptr<ServerSocketFactory> socketFactory)
       : acceptorFactory_(acceptorFactory)
       , exec_(exec)
-      , sockets_(sockets) {
+      , sockets_(sockets)
+      , socketFactory_(socketFactory) {
     CHECK(exec);
   }
 
@@ -128,7 +170,8 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
            std::shared_ptr<Acceptor>> workers_;
   std::shared_ptr<AcceptorFactory> acceptorFactory_;
   folly::wangle::IOThreadPoolExecutor* exec_{nullptr};
-  std::vector<std::shared_ptr<folly::AsyncServerSocket>>* sockets_;
+  std::vector<std::shared_ptr<folly::AsyncSocketBase>>* sockets_;
+  std::shared_ptr<ServerSocketFactory> socketFactory_;
 };
 
 template <typename F>
@@ -138,4 +181,16 @@ void ServerWorkerPool::forEachWorker(F&& f) const {
   }
 }
 
+class DefaultAcceptPipelineFactory
+    : public PipelineFactory<wangle::ChannelPipeline<void*, std::exception>> {
+  typedef wangle::ChannelPipeline<
+      void*,
+      std::exception> AcceptPipeline;
+
+ public:
+  AcceptPipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
+    return new AcceptPipeline;
+  }
+};
+
 } // namespace
index be2add8d4a3b83af8610b4568d9ae363cbc06da5..cd7a88eb950f1dd803248973917ee020e600b953 100644 (file)
@@ -15,6 +15,7 @@
  */
 #include <folly/wangle/bootstrap/ServerBootstrap.h>
 #include <folly/wangle/concurrent/NamedThreadFactory.h>
+#include <folly/wangle/channel/ChannelHandler.h>
 #include <folly/io/async/EventBaseManager.h>
 
 namespace folly {
@@ -25,8 +26,9 @@ void ServerWorkerPool::threadStarted(
   workers_.insert({h, worker});
 
   for(auto socket : *sockets_) {
-    socket->getEventBase()->runInEventBaseThread([this, worker, socket](){
-      socket->addAcceptCallback(worker.get(), worker->getEventBase());
+    socket->getEventBase()->runInEventBaseThreadAndWait([this, worker, socket](){
+        socketFactory_->addAcceptCB(
+          socket, worker.get(), worker->getEventBase());
     });
   }
 }
@@ -38,22 +40,22 @@ void ServerWorkerPool::threadStopped(
 
   for (auto& socket : *sockets_) {
     folly::Baton<> barrier;
-    socket->getEventBase()->runInEventBaseThread([&]() {
-      socket->removeAcceptCallback(worker->second.get(), nullptr);
+    socket->getEventBase()->runInEventBaseThreadAndWait([&]() {
+      socketFactory_->removeAcceptCB(
+        socket, worker->second.get(), nullptr);
       barrier.post();
     });
     barrier.wait();
   }
 
-  CHECK(worker->second->getEventBase() != nullptr);
-  CHECK(!worker->second->getEventBase()->isInEventBaseThread());
-  folly::Baton<> barrier;
-  worker->second->getEventBase()->runInEventBaseThread([&]() {
-      worker->second->dropAllConnections();
-      barrier.post();
-  });
+  if (!worker->second->getEventBase()->isInEventBaseThread()) {
+    worker->second->getEventBase()->runInEventBaseThreadAndWait([=]() {
+        worker->second->dropAllConnections();
+      });
+  } else {
+    worker->second->dropAllConnections();
+  }
 
-  barrier.wait();
   workers_.erase(worker);
 }
 
index 5a65186da013e9e4c3b4e7e2ff15198850f0816a..82465988e7660b763706f1ce7eb02046d79cd266 100644 (file)
@@ -17,6 +17,7 @@
 
 #include <folly/wangle/bootstrap/ServerBootstrap-inl.h>
 #include <folly/Baton.h>
+#include <folly/wangle/channel/ChannelPipeline.h>
 
 namespace folly {
 
@@ -44,16 +45,24 @@ class ServerBootstrap {
   ~ServerBootstrap() {
     stop();
   }
-  /* TODO(davejwatson)
-   *
-   * If there is any work to be done BEFORE handing the work to IO
-   * threads, this handler is where the pipeline to do it would be
-   * set.
-   *
-   * This could be used for things like logging, load balancing, or
-   * advanced load balancing on IO threads.  Netty also provides this.
+
+  typedef wangle::ChannelPipeline<
+   void*,
+   std::exception> AcceptPipeline;
+  /*
+   * Pipeline used to add connections to event bases.
+   * This is used for UDP or for load balancing
+   * TCP connections to IO threads explicitly
    */
-  ServerBootstrap* handler() {
+  ServerBootstrap* pipeline(
+    std::shared_ptr<PipelineFactory<AcceptPipeline>> factory) {
+    pipeline_ = factory;
+    return this;
+  }
+
+  ServerBootstrap* channelFactory(
+    std::shared_ptr<ServerSocketFactory> factory) {
+    socketFactory_ = factory;
     return this;
   }
 
@@ -75,7 +84,7 @@ class ServerBootstrap {
    */
   ServerBootstrap* childPipeline(
       std::shared_ptr<PipelineFactory<Pipeline>> factory) {
-    pipelineFactory_ = factory;
+    childPipelineFactory_ = factory;
     return this;
   }
 
@@ -111,15 +120,19 @@ class ServerBootstrap {
         32, std::make_shared<wangle::NamedThreadFactory>("IO Thread"));
     }
 
-    CHECK(acceptorFactory_ || pipelineFactory_);
+    // TODO better config checking
+    // CHECK(acceptorFactory_ || childPipelineFactory_);
+    CHECK(!(acceptorFactory_ && childPipelineFactory_));
 
     if (acceptorFactory_) {
       workerFactory_ = std::make_shared<ServerWorkerPool>(
-        acceptorFactory_, io_group.get(), &sockets_);
+        acceptorFactory_, io_group.get(), &sockets_, socketFactory_);
     } else {
       workerFactory_ = std::make_shared<ServerWorkerPool>(
-        std::make_shared<ServerAcceptorFactory<Pipeline>>(pipelineFactory_),
-        io_group.get(), &sockets_);
+        std::make_shared<ServerAcceptorFactory<Pipeline>>(
+          childPipelineFactory_,
+          pipeline_),
+        io_group.get(), &sockets_, socketFactory_);
     }
 
     io_group->addObserver(workerFactory_);
@@ -143,13 +156,14 @@ class ServerBootstrap {
     // Since only a single socket is given,
     // we can only accept on a single thread
     CHECK(acceptor_group_->numThreads() == 1);
+
     std::shared_ptr<folly::AsyncServerSocket> socket(
       s.release(), DelayedDestruction::Destructor());
 
     folly::Baton<> barrier;
     acceptor_group_->add([&](){
       socket->attachEventBase(EventBaseManager::get()->getEventBase());
-      socket->listen(1024);
+      socket->listen(socketConfig.acceptBacklog);
       socket->startAccepting();
       barrier.post();
     });
@@ -157,8 +171,9 @@ class ServerBootstrap {
 
     // Startup all the threads
     workerFactory_->forEachWorker([this, socket](Acceptor* worker){
-      socket->getEventBase()->runInEventBaseThread([this, worker, socket](){
-        socket->addAcceptCallback(worker, worker->getEventBase());
+      socket->getEventBase()->runInEventBaseThreadAndWait(
+        [this, worker, socket](){
+          socketFactory_->addAcceptCB(socket, worker, worker->getEventBase());
       });
     });
 
@@ -192,31 +207,16 @@ class ServerBootstrap {
     }
 
     std::mutex sock_lock;
-    std::vector<std::shared_ptr<folly::AsyncServerSocket>> new_sockets;
+    std::vector<std::shared_ptr<folly::AsyncSocketBase>> new_sockets;
+
 
     std::exception_ptr exn;
 
     auto startupFunc = [&](std::shared_ptr<folly::Baton<>> barrier){
-        auto socket = folly::AsyncServerSocket::newSocket();
-        socket->setReusePortEnabled(reusePort);
-        socket->attachEventBase(EventBaseManager::get()->getEventBase());
-
-        try {
-          if (port >= 0) {
-            socket->bind(port);
-          } else {
-            socket->bind(address);
-            port = address.getPort();
-          }
-
-          socket->listen(socketConfig.acceptBacklog);
-          socket->startAccepting();
-        } catch (...) {
-          exn = std::current_exception();
-          barrier->post();
-
-          return;
-        }
+
+      try {
+        auto socket = socketFactory_->newSocket(
+          port, address, socketConfig.acceptBacklog, reusePort, socketConfig);
 
         sock_lock.lock();
         new_sockets.push_back(socket);
@@ -228,6 +228,15 @@ class ServerBootstrap {
         }
 
         barrier->post();
+      } catch (...) {
+        exn = std::current_exception();
+        barrier->post();
+
+        return;
+      }
+
+
+
     };
 
     auto wait0 = std::make_shared<folly::Baton<>>();
@@ -244,16 +253,14 @@ class ServerBootstrap {
       std::rethrow_exception(exn);
     }
 
-    // Startup all the threads
-    for(auto socket : new_sockets) {
+    for (auto& socket : new_sockets) {
+      // Startup all the threads
       workerFactory_->forEachWorker([this, socket](Acceptor* worker){
-        socket->getEventBase()->runInEventBaseThread([this, worker, socket](){
-          socket->addAcceptCallback(worker, worker->getEventBase());
+        socket->getEventBase()->runInEventBaseThreadAndWait([this, worker, socket](){
+          socketFactory_->addAcceptCB(socket, worker, worker->getEventBase());
         });
       });
-    }
 
-    for (auto& socket : new_sockets) {
       sockets_.push_back(socket);
     }
   }
@@ -264,9 +271,8 @@ class ServerBootstrap {
   void stop() {
     for (auto socket : sockets_) {
       folly::Baton<> barrier;
-      socket->getEventBase()->runInEventBaseThread([&barrier, socket]() {
-        socket->stopAccepting();
-        socket->detachEventBase();
+      socket->getEventBase()->runInEventBaseThread([&]() mutable {
+        socketFactory_->stopSocket(socket);
         barrier.post();
       });
       barrier.wait();
@@ -284,7 +290,7 @@ class ServerBootstrap {
   /*
    * Get the list of listening sockets
    */
-  const std::vector<std::shared_ptr<folly::AsyncServerSocket>>&
+  const std::vector<std::shared_ptr<folly::AsyncSocketBase>>&
   getSockets() const {
     return sockets_;
   }
@@ -305,10 +311,14 @@ class ServerBootstrap {
   std::shared_ptr<wangle::IOThreadPoolExecutor> io_group_;
 
   std::shared_ptr<ServerWorkerPool> workerFactory_;
-  std::vector<std::shared_ptr<folly::AsyncServerSocket>> sockets_;
+  std::vector<std::shared_ptr<folly::AsyncSocketBase>> sockets_;
 
   std::shared_ptr<AcceptorFactory> acceptorFactory_;
-  std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
+  std::shared_ptr<PipelineFactory<Pipeline>> childPipelineFactory_;
+  std::shared_ptr<PipelineFactory<AcceptPipeline>> pipeline_{
+    std::make_shared<DefaultAcceptPipelineFactory>()};
+  std::shared_ptr<ServerSocketFactory> socketFactory_{
+    std::make_shared<AsyncServerSocketFactory>()};
 };
 
 } // namespace
diff --git a/folly/wangle/bootstrap/ServerSocketFactory.h b/folly/wangle/bootstrap/ServerSocketFactory.h
new file mode 100644 (file)
index 0000000..ca99c2e
--- /dev/null
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/wangle/bootstrap/ServerBootstrap-inl.h>
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/EventBaseManager.h>
+#include <folly/io/async/AsyncUDPServerSocket.h>
+
+namespace folly {
+
+class ServerSocketFactory {
+ public:
+  virtual std::shared_ptr<AsyncSocketBase> newSocket(
+    int port, SocketAddress address, int backlog,
+    bool reuse, ServerSocketConfig& config) = 0;
+
+  virtual void stopSocket(
+    std::shared_ptr<AsyncSocketBase>& socket) = 0;
+
+  virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> sock, Acceptor *callback, EventBase* base) = 0;
+  virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> sock, Acceptor* callback, EventBase* base) = 0 ;
+  virtual ~ServerSocketFactory() = default;
+};
+
+class AsyncServerSocketFactory : public ServerSocketFactory {
+ public:
+  std::shared_ptr<AsyncSocketBase> newSocket(
+      int port, SocketAddress address, int backlog, bool reuse,
+      ServerSocketConfig& config) {
+
+    auto socket = folly::AsyncServerSocket::newSocket();
+    socket->setReusePortEnabled(reuse);
+    socket->attachEventBase(EventBaseManager::get()->getEventBase());
+    if (port >= 0) {
+      socket->bind(port);
+    } else {
+      socket->bind(address);
+    }
+
+    socket->listen(config.acceptBacklog);
+    socket->startAccepting();
+
+    return socket;
+  }
+
+  virtual void stopSocket(
+    std::shared_ptr<AsyncSocketBase>& s) {
+    auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
+    DCHECK(socket);
+    socket->stopAccepting();
+    socket->detachEventBase();
+  }
+
+  virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> s,
+                              Acceptor *callback, EventBase* base) {
+    auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
+    CHECK(socket);
+    socket->removeAcceptCallback(callback, base);
+  }
+
+  virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> s,
+                                 Acceptor* callback, EventBase* base) {
+    auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
+    CHECK(socket);
+    socket->addAcceptCallback(callback, base);
+  }
+};
+
+class AsyncUDPServerSocketFactory : public ServerSocketFactory {
+ public:
+  std::shared_ptr<AsyncSocketBase> newSocket(
+      int port, SocketAddress address, int backlog, bool reuse,
+      ServerSocketConfig& config) {
+
+    auto socket = std::make_shared<AsyncUDPServerSocket>(
+      EventBaseManager::get()->getEventBase());
+    //socket->setReusePortEnabled(reuse);
+    SocketAddress addressr("::1", port);
+    socket->bind(addressr);
+    socket->listen();
+
+    return socket;
+  }
+
+  virtual void stopSocket(
+    std::shared_ptr<AsyncSocketBase>& s) {
+    auto socket = std::dynamic_pointer_cast<AsyncUDPServerSocket>(s);
+    DCHECK(socket);
+    socket->close();
+  }
+
+  virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> s,
+                              Acceptor *callback, EventBase* base) {
+  }
+
+  virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> s,
+                                 Acceptor* callback, EventBase* base) {
+    auto socket = std::dynamic_pointer_cast<AsyncUDPServerSocket>(s);
+    DCHECK(socket);
+    socket->addListener(base, callback);
+  }
+};
+
+} // namespace