AsyncSSLSocket StartTLS
authorDave Watson <davejwatson@fb.com>
Tue, 9 Jun 2015 18:34:10 +0000 (11:34 -0700)
committerSara Golemon <sgolemon@fb.com>
Tue, 9 Jun 2015 20:21:45 +0000 (13:21 -0700)
Summary:
Adds a StartTLS mode to AsyncSSLSocket.  Previously I could only find anyone doing something like this by using AsyncSocket, calling detachFd, then creating a new AsyncSSLSocket, and calling sslConn/sslAccept.

That had a couple downsides: 1) All pointers to the previous AsyncSocket become invalid and similarly 2) have to be super careful reads/writes happen on the correct socket, are flushed before changing socket types, etc.

This makes it super easy to just use the same AsyncSSLSocket for everything:
a) Create AsyncSSLSocket in StartTLS mode
b) send/recv anything
c) Call sslAccept/sslConn.  Existing writes are still flushed in the correct order, any additional writes are buffered until handshake completes
d) Start receiving encrypted data.

I made it a new mode (vs. the default), since it seems bad to unintentionally send unencrypted data.

Use case is easy secure thrift upgrade (similar to how current kerberos does it)

Test Plan: New unittest

Reviewed By: afrind@fb.com

Subscribers: doug, ssl-diffs@, folly-diffs@, yfeldblum, chalfant, haijunz, andrewcox, alandau, alikhtarov, jsedgwick, simpkins

FB internal diff: D2120114

Signature: t1:2120114:1433798448:caeddc8feb6cc10fb34200ba97ea323bcaf09f7a

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp

index 3eeb932f0592cf9235c6a5b526da4efa105149a8..0489b6a4991de57f1d14c121079604a8f59885ff 100644 (file)
@@ -253,18 +253,22 @@ SSLException::SSLException(int sslError, int errno_copy):
  * Create a client AsyncSSLSocket
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
-                                 EventBase* evb) :
+                               EventBase* evb, bool deferSecurityNegotiation) :
     AsyncSocket(evb),
     ctx_(ctx),
     handshakeTimeout_(this, evb) {
   init();
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
 }
 
 /**
  * Create a server/client AsyncSSLSocket
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                                 EventBase* evb, int fd, bool server) :
+                               EventBase* evb, int fd, bool server,
+                               bool deferSecurityNegotiation) :
     AsyncSocket(evb, fd),
     server_(server),
     ctx_(ctx),
@@ -274,6 +278,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
                               AsyncSSLSocket::sslInfoCallback);
   }
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
@@ -283,8 +290,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
                                  EventBase* evb,
-                                 const std::string& serverName) :
-    AsyncSSLSocket(ctx, evb) {
+                               const std::string& serverName,
+                               bool deferSecurityNegotiation) :
+    AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
 
@@ -294,8 +302,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
                                  EventBase* evb, int fd,
-                                 const std::string& serverName) :
-    AsyncSSLSocket(ctx, evb, fd, false) {
+                               const std::string& serverName,
+                               bool deferSecurityNegotiation) :
+    AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
 #endif
@@ -374,7 +383,7 @@ void AsyncSSLSocket::shutdownWriteNow() {
 bool AsyncSSLSocket::good() const {
   return (AsyncSocket::good() &&
           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
-           sslState_ == STATE_ESTABLISHED));
+           sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
 }
 
 // The TAsyncTransport definition of 'good' states that the transport is
@@ -468,7 +477,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+  if (!server_ || (sslState_ != STATE_UNINIT &&
+                   sslState_ != STATE_UNENCRYPTED) ||
+      handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
@@ -674,7 +685,9 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+  if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
+                  STATE_UNENCRYPTED) ||
+      handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
@@ -1078,6 +1091,10 @@ AsyncSSLSocket::handleRead() noexcept {
 
 ssize_t
 AsyncSSLSocket::performRead(void* buf, size_t buflen) {
+  if (sslState_ == STATE_UNENCRYPTED) {
+    return AsyncSocket::performRead(buf, buflen);
+  }
+
   errno = 0;
   ssize_t bytes = SSL_read(ssl_, buf, buflen);
   if (server_ && renegotiateAttempted_) {
@@ -1169,6 +1186,10 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
                                       WriteFlags flags,
                                       uint32_t* countWritten,
                                       uint32_t* partialWritten) {
+  if (sslState_ == STATE_UNENCRYPTED) {
+    return AsyncSocket::performWrite(
+      vec, count, flags, countWritten, partialWritten);
+  }
   if (sslState_ != STATE_ESTABLISHED) {
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
                << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
index cce4c18b5fd3fa7d452178133e9552fee9a58595..393b87650defdc9a6f636f86dba4daaf9e130565 100644 (file)
@@ -162,7 +162,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Create a client AsyncSSLSocket
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
-                  EventBase* evb);
+                 EventBase* evb, bool deferSecurityNegotiation = false);
 
   /**
    * Create a server/client AsyncSSLSocket from an already connected
@@ -178,9 +178,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param evb EventBase that will manage this socket.
    * @param fd  File descriptor to take over (should be a connected socket).
    * @param server Is socket in server mode?
+   * @param deferSecurityNegotiation
+   *          unencrypted data can be sent before sslConn/Accept
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                  EventBase* evb, int fd, bool server = true);
+                 EventBase* evb, int fd,
+                 bool server = true, bool deferSecurityNegotiation = false);
 
 
   /**
@@ -188,9 +191,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
-    EventBase* evb, int fd, bool server=true) {
+    EventBase* evb, int fd, bool server=true,
+    bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb, fd, server),
+      new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
       Destructor());
   }
 
@@ -199,9 +203,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
-    EventBase* evb) {
+    EventBase* evb, bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb),
+      new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
       Destructor());
   }
 
@@ -213,7 +217,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
                   EventBase* evb,
-                  const std::string& serverName);
+                 const std::string& serverName,
+                bool deferSecurityNegotiation = false);
 
   /**
    * Create a client AsyncSSLSocket from an already connected
@@ -233,14 +238,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
                   EventBase* evb,
                   int fd,
-                  const std::string& serverName);
+                 const std::string& serverName,
+                bool deferSecurityNegotiation = false);
 
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
     EventBase* evb,
-    const std::string& serverName) {
+    const std::string& serverName,
+    bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb, serverName),
+      new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
       Destructor());
   }
 #endif
@@ -336,6 +343,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   enum SSLStateEnum {
     STATE_UNINIT,
+    STATE_UNENCRYPTED,
     STATE_ACCEPTING,
     STATE_CACHE_LOOKUP,
     STATE_RSA_ASYNC_PENDING,
index 20f782a1e59d535dfa6d75c30cebff269f5c1806..b3759cfe8f4f4fe19be625d027cbd970f3216de6 100644 (file)
@@ -1262,8 +1262,90 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
   socket->setMinWriteSize(50000);
   EXPECT_EQ(50000, socket->getMinWriteSize());
 }
+
+class ReadCallbackTerminator : public ReadCallback {
+ public:
+  ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
+      : ReadCallback(wcb)
+      , base_(base) {}
+
+  // Do not write data back, terminate the loop.
+  void readDataAvailable(size_t len) noexcept override {
+    std::cerr << "readDataAvailable, len " << len << std::endl;
+
+    currentBuffer.length = len;
+
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+    state = STATE_SUCCEEDED;
+
+    socket_->setReadCB(nullptr);
+    base_->terminateLoopSoon();
+  }
+ private:
+  EventBase* base_;
+};
+
+
+/**
+ * Test a full unencrypted codepath
+ */
+TEST(AsyncSSLSocketTest, UnencryptedTest) {
+  EventBase base;
+
+  auto clientCtx = std::make_shared<folly::SSLContext>();
+  auto serverCtx = std::make_shared<folly::SSLContext>();
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+  auto client = AsyncSSLSocket::newSocket(
+                  clientCtx, &base, fds[0], false, true);
+  auto server = AsyncSSLSocket::newSocket(
+                  serverCtx, &base, fds[1], true, true);
+
+  ReadCallbackTerminator readCallback(&base, nullptr);
+  server->setReadCB(&readCallback);
+  readCallback.setSocket(server);
+
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  client->write(nullptr, buf, sizeof(buf));
+
+  // Check that bytes are unencrypted
+  char c;
+  EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
+  EXPECT_EQ('a', c);
+
+  EventBaseAborter eba(&base, 3000);
+  base.loop();
+
+  EXPECT_EQ(1, readCallback.buffers.size());
+  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
+
+  server->setReadCB(&readCallback);
+
+  // Unencrypted
+  server->sslAccept(nullptr);
+  client->sslConn(nullptr);
+
+  // Do NOT wait for handshake, writing should be queued and happen after
+
+  client->write(nullptr, buf, sizeof(buf));
+
+  // Check that bytes are *not* unencrypted
+  char c2;
+  EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
+  EXPECT_NE('a', c2);
+
+
+  base.loop();
+
+  EXPECT_EQ(2, readCallback.buffers.size());
+  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
 }
 
+} // namespace
+
 ///////////////////////////////////////////////////////////////////////////
 // init_unit_test_suite
 ///////////////////////////////////////////////////////////////////////////