Add TFO support to AsyncSSLSocket
authorSubodh Iyengar <subodh@fb.com>
Tue, 7 Jun 2016 14:43:37 +0000 (07:43 -0700)
committerFacebook Github Bot 7 <facebook-github-bot-7-bot@fb.com>
Tue, 7 Jun 2016 14:53:33 +0000 (07:53 -0700)
Summary:
This adds TFO support to AsyncSSLSocket which
uses the support for TFO from AsyncSocket.

Because of the way AsyncSSLSocket inherits from
AsyncSocket it is tricky.

The following changes were made:
1. Openssl internally will treat only errors with return
code -1 as READ_REQUIRED or WRITE_REQUIRED errors. So this
diff changes the return value of the errors in the TFO fallback
cases to -1.

2. In case we fallback after SSL_connect() to a normal connect,
we would have to restart the connection process after connect
succeeds. To do this this overrides the connection success callback
and restarts the connection before sending the callback to AsyncSocket
because sometimes callbacks might synchronously call sslConn() in the
normal connect cases.

3. Delegated bioWrite to call sendSocketMessage instead of sendmsg directly.

Reviewed By: djwatson

Differential Revision: D3391735

fbshipit-source-id: 61434f6de4a9c3d03973c9ab9e51eb49e751e5cf

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/BlockingSocket.h
folly/io/async/test/SocketClient.cpp

index 4fa6b6fef26b57f6cbb62cb1c083a07f40fa1e3b..f91416ed61ac3c0dcff529b329a4165ebcf43719 100644 (file)
@@ -1084,8 +1084,9 @@ AsyncSSLSocket::handleConnect() noexcept {
     return AsyncSocket::handleConnect();
   }
 
-  assert(state_ == StateEnum::ESTABLISHED &&
-         sslState_ == STATE_CONNECTING);
+  assert(
+      (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
+      sslState_ == STATE_CONNECTING);
   assert(ssl_);
 
   int ret = SSL_connect(ssl_);
@@ -1138,6 +1139,16 @@ AsyncSSLSocket::handleConnect() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::invokeConnectSuccess() {
+  if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+    // If we failed TFO, we'd fall back to trying to connect the socket,
+    // when we succeed we should handle the writes that caused us to start
+    // TFO.
+    handleWrite();
+  }
+  AsyncSocket::invokeConnectSuccess();
+}
+
 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
   // turn on the buffer movable in openssl
@@ -1498,7 +1509,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
 }
 
 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
-  int ret;
   struct msghdr msg;
   struct iovec iov;
   int flags = 0;
@@ -1521,17 +1531,20 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
     flags = MSG_EOR;
   }
 
-  ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
+  auto result =
+      tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
   BIO_clear_retry_flags(b);
-  if (ret <= 0) {
-    if (BIO_sock_should_retry(ret))
+  if (!result.exception && result.writeReturn <= 0) {
+    if (BIO_sock_should_retry(result.writeReturn)) {
       BIO_set_retry_write(b);
+    }
   }
-  return ret;
+  return result.writeReturn;
 }
 
-int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
-                                       X509_STORE_CTX* x509Ctx) {
+int AsyncSSLSocket::sslVerifyCallback(
+    int preverifyOk,
+    X509_STORE_CTX* x509Ctx) {
   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
index 40ceb87aa51a2935f604aa5e755b23a049a1f5d9..bd4f76d12ab2080327fee9996e4d18147a7fb08a 100644 (file)
@@ -798,6 +798,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
+  void invokeConnectSuccess() override;
+
   void cacheLocalPeerAddr();
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
index cffe1f82ee1f9f2c2cff8ad1fe08c98a8dfe7c84..6bb9126efd7a086155f365b4985a49b119c8c6f4 100644 (file)
@@ -1752,9 +1752,8 @@ ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
   return detail::tfo_sendmsg(fd, msg, msg_flags);
 }
 
-AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
-    struct msghdr* msg,
-    int msg_flags) {
+AsyncSocket::WriteResult
+AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
   ssize_t totalWritten = 0;
   if (state_ == StateEnum::FAST_OPEN) {
     sockaddr_storage addr;
@@ -1778,11 +1777,9 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
         return WriteResult(
             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
       }
-      // Let's fake it that no bytes were written.
-      // Some clients check errno even if return code is 0, so we
-      // set it just in case.
+      // Let's fake it that no bytes were written and return an errno.
       errno = EAGAIN;
-      totalWritten = 0;
+      totalWritten = -1;
     } else if (errno == EOPNOTSUPP) {
       VLOG(4) << "TFO not supported";
       // Try falling back to connecting.
@@ -1797,10 +1794,8 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
         }
         // If there was no exception during connections,
         // we would return that no bytes were written.
-        // Some clients check errno even if return code is 0, so we
-        // set it just in case.
         errno = EAGAIN;
-        totalWritten = 0;
+        totalWritten = -1;
       } catch (const AsyncSocketException& ex) {
         return WriteResult(
             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
@@ -1816,7 +1811,7 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
               AsyncSocketException::UNKNOWN, "No more free local ports"));
     }
   } else {
-    totalWritten = ::sendmsg(fd_, msg, msg_flags);
+    totalWritten = ::sendmsg(fd, msg, msg_flags);
   }
   return WriteResult(totalWritten);
 }
@@ -1855,7 +1850,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // marks that this is the last byte of a record (response)
     msg_flags |= MSG_EOR;
   }
-  auto writeResult = sendSocketMessage(&msg, msg_flags);
+  auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
   auto totalWritten = writeResult.writeReturn;
   if (totalWritten < 0) {
     if (!writeResult.exception && errno == EAGAIN) {
index f3e605d251cb95533315caf31af03d562f5027fb..36949725c3558639de74e124f423430ee951e788 100644 (file)
@@ -817,7 +817,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * @param msg       Message to send
    * @param msg_flags Flags to pass to sendmsg
    */
-  AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags);
+  AsyncSocket::WriteResult
+  sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
 
   virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
 
@@ -855,7 +856,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void failWrite(const char* fn, const AsyncSocketException& ex);
   void failAllWrites(const AsyncSocketException& ex);
   void invokeConnectErr(const AsyncSocketException& ex);
-  void invokeConnectSuccess();
+  virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
   void invalidState(ReadCallback* callback);
   void invalidState(WriteCallback* callback);
index d4078e56b01818c2d5f74234c6bd5c00a98f1c96..fd2b2521f7f39893cc37a025aae94d52c4a122a2 100644 (file)
  */
 #include <folly/io/async/test/AsyncSSLSocketTest.h>
 
-#include <signal.h>
 #include <pthread.h>
+#include <signal.h>
 
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
 #include <folly/io/async/test/BlockingSocket.h>
 
-#include <fstream>
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
 #include <gtest/gtest.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
+#include <fstream>
 #include <iostream>
 #include <list>
 #include <set>
-#include <fcntl.h>
-#include <openssl/bio.h>
-#include <sys/types.h>
-#include <folly/io/Cursor.h>
+
+#include <gmock/gmock.h>
 
 using std::string;
 using std::vector;
@@ -43,6 +45,8 @@ using std::cerr;
 using std::endl;
 using std::list;
 
+using namespace testing;
+
 namespace folly {
 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
@@ -55,7 +59,7 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
 constexpr size_t SSLClient::kMaxReadBufferSz;
 constexpr size_t SSLClient::kMaxReadsPerEvent;
 
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
     : ctx_(new folly::SSLContext),
       acb_(acb),
       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
@@ -67,7 +71,13 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
   acb_->ctx_ = ctx_;
   acb_->base_ = &evb_;
 
-  //set up the listening socket
+  // Enable TFO
+  if (enableTFO) {
+    LOG(INFO) << "server TFO enabled";
+    socket_->setTFOEnabled(true, 1000);
+  }
+
+  // set up the listening socket
   socket_->bind(0);
   socket_->getAddress(&address_);
   socket_->listen(100);
@@ -1674,6 +1684,203 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
             std::string::npos);
 }
 
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+  explicit MockAsyncTFOSSLSocket(
+      std::shared_ptr<folly::SSLContext> sslCtx,
+      EventBase* evb)
+      : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, false);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  virtual void connectSuccess() noexcept override {
+    state = State::SUCCESS;
+  }
+
+  virtual void connectErr(const AsyncSocketException&) noexcept override {
+    state = State::ERROR;
+  }
+
+  enum class State { WAITING, SUCCESS, ERROR };
+
+  State state{State::WAITING};
+};
+
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+    EventBase* evb,
+    const SocketAddress& address) {
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+      new MockAsyncTFOSSLSocket(sslContext, evb));
+  socket->enableTFO();
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+        sockaddr_storage addr;
+        auto len = address.getAddress(&addr);
+        return connect(fd, (const struct sockaddr*)&addr, len);
+      }));
+  return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress());
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+  evb.loop();
+
+  BlockingSocket sock(std::move(socket));
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  sock.write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  EXPECT_THROW(
+      socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress());
+  ConnCallback ccb;
+  // Set a short timeout
+  socket->connect(&ccb, server.getAddress(), 1);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+#endif
+
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////
index 43ada2497afea62d0b8a5d0bc95da2a50a9fa485..e4b512496c3941afbe5ea2f2cae69597c1daa46d 100644 (file)
@@ -607,7 +607,9 @@ class TestSSLServer {
  public:
   // Create a TestSSLServer.
   // This immediately starts listening on the given port.
-  explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+  explicit TestSSLServer(
+      SSLServerAcceptCallbackBase* acb,
+      bool enableTFO = false);
 
   // Kill the thread.
   ~TestSSLServer() {
index 3830648e61d43a9e3d0407857a3f8287f9fc0c1e..b3713fc241bdfe55ec484a64acce1608934f0603 100644 (file)
 #pragma once
 
 #include <folly/Optional.h>
-#include <folly/io/async/SSLContext.h>
-#include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/SSLContext.h>
 
 class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
                        public folly::AsyncTransportWrapper::ReadCallback,
-                       public folly::AsyncTransportWrapper::WriteCallback
-{
+                       public folly::AsyncTransportWrapper::WriteCallback {
  public:
   explicit BlockingSocket(int fd)
-    : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
-  }
+      : sock_(new folly::AsyncSocket(&eventBase_, fd)) {}
 
-  BlockingSocket(folly::SocketAddress address,
-                 std::shared_ptr<folly::SSLContext> sslContext)
-    : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
-            new folly::AsyncSocket(&eventBase_)),
-    address_(address) {}
+  BlockingSocket(
+      folly::SocketAddress address,
+      std::shared_ptr<folly::SSLContext> sslContext)
+      : sock_(
+            sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
+                       : new folly::AsyncSocket(&eventBase_)),
+        address_(address) {}
 
   explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
       : sock_(std::move(socket)) {
     sock_->attachEventBase(&eventBase_);
   }
 
+  void enableTFO() {
+    sock_->enableTFO();
+  }
+
   void setAddress(folly::SocketAddress address) {
     address_ = address;
   }
 
-  void open() {
-    sock_->connect(this, address_);
+  void open(
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
+    sock_->connect(this, address_, timeout.count());
     eventBase_.loop();
     if (err_.hasValue()) {
       throw err_.value();
@@ -54,7 +59,9 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
   void close() {
     sock_->close();
   }
-  void closeWithReset() { sock_->closeWithReset(); }
+  void closeWithReset() {
+    sock_->closeWithReset();
+  }
 
   int32_t write(uint8_t const* buf, size_t len) {
     sock_->write(this, buf, len);
@@ -67,11 +74,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
 
   void flush() {}
 
-  int32_t readAll(uint8_t *buf, size_t len) {
+  int32_t readAll(uint8_tbuf, size_t len) {
     return readHelper(buf, len, true);
   }
 
-  int32_t read(uint8_t *buf, size_t len) {
+  int32_t read(uint8_tbuf, size_t len) {
     return readHelper(buf, len, false);
   }
 
@@ -83,7 +90,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
   folly::EventBase eventBase_;
   folly::AsyncSocket::UniquePtr sock_;
   folly::Optional<folly::AsyncSocketException> err_;
-  uint8_t *readBuf_{nullptr};
+  uint8_treadBuf_{nullptr};
   size_t readLen_{0};
   folly::SocketAddress address_;
 
@@ -102,18 +109,18 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
       sock_->setReadCB(nullptr);
     }
   }
-  void readEOF() noexcept override {
-  }
+  void readEOF() noexcept override {}
   void readErr(const folly::AsyncSocketException& ex) noexcept override {
     err_ = ex;
   }
   void writeSuccess() noexcept override {}
-  void writeErr(size_t /* bytesWritten */,
-                const folly::AsyncSocketException& ex) noexcept override {
+  void writeErr(
+      size_t /* bytesWritten */,
+      const folly::AsyncSocketException& ex) noexcept override {
     err_ = ex;
   }
 
-  int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+  int32_t readHelper(uint8_tbuf, size_t len, bool all) {
     if (!sock_->good()) {
       return 0;
     }
@@ -132,8 +139,8 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
       throw err_.value();
     }
     if (all && readLen_ > 0) {
-      throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
-                                        "eof");
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::UNKNOWN, "eof");
     }
     return len - readLen_;
   }
index 7f20d480e20695eeaf7178c1f4d916078baaf07b..23bef722934fc51d60aa2734c9c351e175735482 100644 (file)
@@ -24,6 +24,7 @@ DEFINE_string(host, "localhost", "Host");
 DEFINE_int32(port, 0, "port");
 DEFINE_bool(tfo, false, "enable tfo");
 DEFINE_string(msg, "", "Message to send");
+DEFINE_bool(ssl, false, "use ssl");
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -35,7 +36,13 @@ int main(int argc, char** argv) {
 
   // Prep the socket
   EventBase evb;
-  AsyncSocket::UniquePtr socket(new AsyncSocket(&evb));
+  AsyncSocket::UniquePtr socket;
+  if (FLAGS_ssl) {
+    auto sslContext = std::make_shared<SSLContext>();
+    socket = AsyncSocket::UniquePtr(new AsyncSSLSocket(sslContext, &evb));
+  } else {
+    socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
+  }
   socket->detachEventBase();
 
   if (FLAGS_tfo) {