Add evb change callback to SSL Socket
authorNeel Goyal <ngoyal@fb.com>
Tue, 13 Dec 2016 18:25:19 +0000 (10:25 -0800)
committerFacebook Github Bot <facebook-github-bot-bot@fb.com>
Tue, 13 Dec 2016 18:32:59 +0000 (10:32 -0800)
Summary: Allow observers to be notified when AsyncSocket attaches and detaches from EVB

Reviewed By: siyengar

Differential Revision: D4256175

fbshipit-source-id: a3ff96811f885e508f20cf11ce52e0f00e5ee461

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest2.cpp
folly/io/async/test/AsyncSocketTest2.cpp

index aa34b1a..fefd9d0 100644 (file)
@@ -464,18 +464,36 @@ void AsyncSSLSocket::attachSSLContext(
   DCHECK(ctx->getSSLCtx());
   ctx_ = ctx;
 
+  // It's possible this could be attached before ssl_ is set up
+  if (!ssl_) {
+    return;
+  }
+
   // In order to call attachSSLContext, detachSSLContext must have been
-  // previously called which sets the socket's context to the dummy
-  // context. Thus we must acquire this lock.
+  // previously called.
+  // We need to update the initial_ctx if necessary
+  auto sslCtx = ctx->getSSLCtx();
+#ifndef OPENSSL_NO_TLSEXT
+  CRYPTO_add(&sslCtx->references, 1, CRYPTO_LOCK_SSL_CTX);
+  // note that detachSSLContext has already freed ssl_->initial_ctx
+  ssl_->initial_ctx = sslCtx;
+#endif
+  // Detach sets the socket's context to the dummy context. Thus we must acquire
+  // this lock.
   SpinLockGuard guard(dummyCtxLock);
-  SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
+  SSL_set_SSL_CTX(ssl_, sslCtx);
 }
 
 void AsyncSSLSocket::detachSSLContext() {
   DCHECK(ctx_);
   ctx_.reset();
-  // We aren't using the initial_ctx for now, and it can introduce race
-  // conditions in the destructor of the SSL object.
+  // It's possible for this to be called before ssl_ has been
+  // set up
+  if (!ssl_) {
+    return;
+  }
+// Detach the initial_ctx as well.  Internally w/ OPENSSL_NO_TLSEXT
+// it is used for session info.  It will be reattached in attachSSLContext
 #ifndef OPENSSL_NO_TLSEXT
   if (ssl_->initial_ctx) {
     SSL_CTX_free(ssl_->initial_ctx);
index 40cddc1..fb3aaed 100644 (file)
@@ -1121,6 +1121,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
   eventBase_ = eventBase;
   ioHandler_.attachEventBase(eventBase);
   writeTimeout_.attachEventBase(eventBase);
+  if (evbChangeCb_) {
+    evbChangeCb_->evbAttached(this);
+  }
 }
 
 void AsyncSocket::detachEventBase() {
@@ -1133,6 +1136,9 @@ void AsyncSocket::detachEventBase() {
   eventBase_ = nullptr;
   ioHandler_.detachEventBase();
   writeTimeout_.detachEventBase();
+  if (evbChangeCb_) {
+    evbChangeCb_->evbDetached(this);
+  }
 }
 
 bool AsyncSocket::isDetachable() const {
index 973e493..2624d92 100644 (file)
@@ -94,6 +94,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       noexcept = 0;
   };
 
+  class EvbChangeCallback {
+   public:
+    virtual ~EvbChangeCallback() = default;
+
+    // Called when the socket has been attached to a new EVB
+    // and is called from within that EVB thread
+    virtual void evbAttached(AsyncSocket* socket) = 0;
+
+    // Called when the socket is detached from an EVB and
+    // is called from the EVB thread being detached
+    virtual void evbDetached(AsyncSocket* socket) = 0;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -560,6 +573,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   void setBufferCallback(BufferCallback* cb);
 
+  // Callers should set this prior to connecting the socket for the safest
+  // behavior.
+  void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
+    evbChangeCb_ = std::move(cb);
+  }
+
   /**
    * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
    * If no data has been written, 0 is returned.
@@ -930,6 +949,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   bool tfoEnabled_{false};
   bool tfoAttempted_{false};
   bool tfoFinished_{false};
+
+  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
 };
 #ifdef _MSC_VER
 #pragma vtordisp(pop)
index e33df0a..4d766b6 100644 (file)
 
 #include <pthread.h>
 
+#include <folly/futures/Promise.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
 #include <folly/io/async/SSLContext.h>
+#include <folly/io/async/ScopedEventBaseThread.h>
 #include <folly/portability/GTest.h>
 
 using std::string;
@@ -31,44 +33,100 @@ using std::list;
 
 namespace folly {
 
+struct EvbAndContext {
+  EvbAndContext() {
+    ctx_.reset(new SSLContext());
+    ctx_->setOptions(SSL_OP_NO_TICKET);
+    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  }
+
+  std::shared_ptr<AsyncSSLSocket> createSocket() {
+    return AsyncSSLSocket::newSocket(ctx_, getEventBase());
+  }
+
+  EventBase* getEventBase() {
+    return evb_.getEventBase();
+  }
+
+  void attach(AsyncSSLSocket& socket) {
+    socket.attachEventBase(getEventBase());
+    socket.attachSSLContext(ctx_);
+  }
+
+  folly::ScopedEventBaseThread evb_;
+  std::shared_ptr<SSLContext> ctx_;
+};
+
 class AttachDetachClient : public AsyncSocket::ConnectCallback,
                            public AsyncTransportWrapper::WriteCallback,
                            public AsyncTransportWrapper::ReadCallback {
  private:
-  EventBase *eventBase_;
+  // two threads here - we'll create the socket in one, connect
+  // in the other, and then read/write in the initial one
+  EvbAndContext t1_;
+  EvbAndContext t2_;
   std::shared_ptr<AsyncSSLSocket> sslSocket_;
-  std::shared_ptr<SSLContext> ctx_;
   folly::SocketAddress address_;
   char buf_[128];
   char readbuf_[128];
   uint32_t bytesRead_;
+  // promise to fulfill when done
+  folly::Promise<bool> promise_;
+
+  void detach() {
+    sslSocket_->detachEventBase();
+    sslSocket_->detachSSLContext();
+  }
+
  public:
-  AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address)
-      : eventBase_(eventBase), address_(address), bytesRead_(0) {
-    ctx_.reset(new SSLContext());
-    ctx_->setOptions(SSL_OP_NO_TICKET);
-    ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  explicit AttachDetachClient(const folly::SocketAddress& address)
+      : address_(address), bytesRead_(0) {}
+
+  Future<bool> getFuture() {
+    return promise_.getFuture();
   }
 
   void connect() {
-    sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
-    sslSocket_->connect(this, address_);
+    // create in one and then move to another
+    auto t1Evb = t1_.getEventBase();
+    t1Evb->runInEventBaseThread([this] {
+      sslSocket_ = t1_.createSocket();
+      // ensure we can detach and reattach the context before connecting
+      for (int i = 0; i < 1000; ++i) {
+        sslSocket_->detachSSLContext();
+        sslSocket_->attachSSLContext(t1_.ctx_);
+      }
+      // detach from t1 and connect in t2
+      detach();
+      auto t2Evb = t2_.getEventBase();
+      t2Evb->runInEventBaseThread([this] {
+        t2_.attach(*sslSocket_);
+        sslSocket_->connect(this, address_);
+      });
+    });
   }
 
   void connectSuccess() noexcept override {
+    auto t2Evb = t2_.getEventBase();
+    EXPECT_TRUE(t2Evb->isInEventBaseThread());
     cerr << "client SSL socket connected" << endl;
-
     for (int i = 0; i < 1000; ++i) {
       sslSocket_->detachSSLContext();
-      sslSocket_->attachSSLContext(ctx_);
+      sslSocket_->attachSSLContext(t2_.ctx_);
     }
 
-    EXPECT_EQ(ctx_->getSSLCtx()->references, 2);
-
-    sslSocket_->write(this, buf_, sizeof(buf_));
-    sslSocket_->setReadCB(this);
-    memset(readbuf_, 'b', sizeof(readbuf_));
-    bytesRead_ = 0;
+    // detach from t2 and then read/write in t1
+    t2Evb->runInEventBaseThread([this] {
+      detach();
+      auto t1Evb = t1_.getEventBase();
+      t1Evb->runInEventBaseThread([this] {
+        t1_.attach(*sslSocket_);
+        sslSocket_->write(this, buf_, sizeof(buf_));
+        sslSocket_->setReadCB(this);
+        memset(readbuf_, 'b', sizeof(readbuf_));
+        bytesRead_ = 0;
+      });
+    });
   }
 
   void connectErr(const AsyncSocketException& ex) noexcept override
@@ -96,14 +154,19 @@ class AttachDetachClient : public AsyncSocket::ConnectCallback,
 
   void readErr(const AsyncSocketException& ex) noexcept override {
     cerr << "client readError: " << ex.what() << endl;
+    promise_.setException(ex);
   }
 
   void readDataAvailable(size_t len) noexcept override {
+    EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
+    EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
     cerr << "client read data: " << len << endl;
     bytesRead_ += len;
     if (len == sizeof(buf_)) {
       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
       sslSocket_->closeNow();
+      sslSocket_.reset();
+      promise_.setValue(true);
     }
   }
 };
@@ -119,13 +182,12 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
   TestSSLServer server(&acceptCallback);
 
-  EventBase eventBase;
-  EventBaseAborter eba(&eventBase, 3000);
   std::shared_ptr<AttachDetachClient> client(
-    new AttachDetachClient(&eventBase, server.getAddress()));
+      new AttachDetachClient(server.getAddress()));
 
+  auto f = client->getFuture();
   client->connect();
-  eventBase.loop();
+  EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
 }
 
 }  // folly
index d4e0749..9fb53b2 100644 (file)
@@ -2996,4 +2996,24 @@ TEST(AsyncSocketTest, ConnectTFOWithBigData) {
   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
 }
 
+class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
+ public:
+  MOCK_METHOD1(evbAttached, void(AsyncSocket*));
+  MOCK_METHOD1(evbDetached, void(AsyncSocket*));
+};
+
+TEST(AsyncSocketTest, EvbCallbacks) {
+  auto cb = folly::make_unique<MockEvbChangeCallback>();
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  InSequence seq;
+  EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
+  EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
+
+  socket->setEvbChangedCallback(std::move(cb));
+  socket->detachEventBase();
+  socket->attachEventBase(&evb);
+}
+
 #endif