Add TLS 1.2+ version for contexts
authorNeel Goyal <ngoyal@fb.com>
Tue, 1 Aug 2017 19:18:25 +0000 (12:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 1 Aug 2017 19:29:30 +0000 (12:29 -0700)
Summary: Add an SSL Version that specifies only TLS 1.2 and up.  This prevents any client with less than TLS 1.2 from connecting.

Reviewed By: knekritz

Differential Revision: D5537423

fbshipit-source-id: 131f5b124af379eaa2b443052be9b43290c41820

folly/io/async/SSLContext.cpp
folly/io/async/SSLContext.h
folly/io/async/test/AsyncSSLSocketTest2.cpp
folly/io/async/test/TestSSLServer.cpp
folly/io/async/test/TestSSLServer.h

index 3d440a8..45936d0 100644 (file)
@@ -49,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) {
     case SSLv3:
       opt = SSL_OP_NO_SSLv2;
       break;
+    case TLSv1_2:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+          SSL_OP_NO_TLSv1_1;
+      break;
     default:
       // do nothing
       break;
index ded583f..d556806 100644 (file)
@@ -68,11 +68,11 @@ class PasswordCollector {
  */
 class SSLContext {
  public:
-
   enum SSLVersion {
-     SSLv2,
-     SSLv3,
-     TLSv1
+    SSLv2,
+    SSLv3,
+    TLSv1, // support TLS 1.0+
+    TLSv1_2, // support for only TLS 1.2+
   };
 
   /**
index eb5490c..ae9ef53 100644 (file)
@@ -190,6 +190,93 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
   EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
 }
 
+class ConnectClient : public AsyncSocket::ConnectCallback {
+ public:
+  ConnectClient() = default;
+
+  Future<bool> getFuture() {
+    return promise_.getFuture();
+  }
+
+  void connect(const folly::SocketAddress& addr) {
+    t1_.getEventBase()->runInEventBaseThread([&] {
+      socket_ = t1_.createSocket();
+      socket_->connect(this, addr);
+    });
+  }
+
+  void connectSuccess() noexcept override {
+    promise_.setValue(true);
+    socket_.reset();
+  }
+
+  void connectErr(const AsyncSocketException& /* ex */) noexcept override {
+    promise_.setValue(false);
+    socket_.reset();
+  }
+
+  void setCtx(std::shared_ptr<SSLContext> ctx) {
+    t1_.ctx_ = ctx;
+  }
+
+ private:
+  EvbAndContext t1_;
+  // promise to fulfill when done with a value of true if connect succeeded
+  folly::Promise<bool> promise_;
+  std::shared_ptr<AsyncSSLSocket> socket_;
+};
+
+class NoopReadCallback : public ReadCallbackBase {
+ public:
+  NoopReadCallback() : ReadCallbackBase(nullptr) {
+    state = STATE_SUCCEEDED;
+  }
+
+  void getReadBuffer(void** buf, size_t* lenReturn) override {
+    *buf = &buffer_;
+    *lenReturn = 1;
+  }
+  void readDataAvailable(size_t) noexcept override {}
+
+  uint8_t buffer_{0};
+};
+
+TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
+  // Start listening on a local port
+  NoopReadCallback readCallback;
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
+  TestSSLServer server(&acceptCallback, ctx);
+  server.loadTestCerts();
+
+  // create a default client
+  auto c1 = std::make_unique<ConnectClient>();
+  auto f1 = c1->getFuture();
+  c1->connect(server.getAddress());
+  EXPECT_TRUE(f1.within(std::chrono::seconds(3)).get());
+}
+
+TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
+  // Start listening on a local port
+  NoopReadCallback readCallback;
+  HandshakeCallback handshakeCallback(
+      &readCallback, HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
+  TestSSLServer server(&acceptCallback, ctx);
+  server.loadTestCerts();
+
+  // create a client that doesn't speak TLS 1.2
+  auto c2 = std::make_unique<ConnectClient>();
+  auto clientCtx = std::make_shared<SSLContext>();
+  clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
+  c2->setCtx(clientCtx);
+  auto f2 = c2->getFuture();
+  c2->connect(server.getAddress());
+  EXPECT_FALSE(f2.within(std::chrono::seconds(3)).get());
+}
+
 } // namespace folly
 
 int main(int argc, char *argv[]) {
index bc127db..46d6743 100644 (file)
@@ -40,6 +40,11 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
   init(enableTFO);
 }
 
+void TestSSLServer::loadTestCerts() {
+  ctx_->loadCertificate(kTestCert);
+  ctx_->loadPrivateKey(kTestKey);
+}
+
 TestSSLServer::TestSSLServer(
     SSLServerAcceptCallbackBase* acb,
     std::shared_ptr<SSLContext> ctx,
index a25bbb8..506b3d1 100644 (file)
@@ -111,6 +111,8 @@ class TestSSLServer {
     return evb_;
   }
 
+  void loadTestCerts();
+
   const SocketAddress& getAddress() const {
     return address_;
   }