From 3e19d28a142149241d81c5e736aa4117fe7cbec8 Mon Sep 17 00:00:00 2001 From: Xiangyu Bu Date: Mon, 27 Nov 2017 15:37:45 -0800 Subject: [PATCH 1/1] Revise API to load cert/key in SSLContext. Summary: When loading cert/key pair, order matters: (a) Wrong key will fail to load if a cert is loaded; (b) Wrong cert will succeed to load even if a private key is loaded. So this diff adds: (1) SSLContext::checkPrivateKey() -- must call for case (b). (2) SSLContext::loadCertKeyPairFromBufferPEM() -- use this if one loads both cert and key. Guaranteed to throw if cert/key mismatch. Reviewed By: yfeldblum Differential Revision: D6416280 fbshipit-source-id: 8ae370883d46e9b5afb69c506c09fbf7ba82b1b9 --- folly/io/async/SSLContext.cpp | 20 ++++++ folly/io/async/SSLContext.h | 36 +++++++++++ folly/io/async/test/SSLContextTest.cpp | 89 ++++++++++++++++++++++++++ 3 files changed, 145 insertions(+) diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 498d8dfb..7ce05f44 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -287,6 +287,26 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { } } +void SSLContext::loadCertKeyPairFromBufferPEM( + folly::StringPiece cert, + folly::StringPiece pkey) { + loadCertificateFromBufferPEM(cert); + loadPrivateKeyFromBufferPEM(pkey); +} + +void SSLContext::loadCertKeyPairFromFiles( + const char* certPath, + const char* keyPath, + const char* certFormat, + const char* keyFormat) { + loadCertificate(certPath, certFormat); + loadPrivateKey(keyPath, keyFormat); +} + +bool SSLContext::isCertKeyPairValid() const { + return SSL_CTX_check_private_key(ctx_) == 1; +} + void SSLContext::loadTrustedCertificates(const char* path) { if (path == nullptr) { throw std::invalid_argument("loadTrustedCertificates: is nullptr"); diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h index c8db033e..bdd04509 100644 --- a/folly/io/async/SSLContext.h +++ b/folly/io/async/SSLContext.h @@ -275,6 +275,7 @@ class SSLContext { * @param cert A PEM formatted certificate */ virtual void loadCertificateFromBufferPEM(folly::StringPiece cert); + /** * Load private key. * @@ -288,6 +289,41 @@ class SSLContext { * @param pkey A PEM formatted key */ virtual void loadPrivateKeyFromBufferPEM(folly::StringPiece pkey); + + /** + * Load cert and key from PEM buffers. Guaranteed to throw if cert and + * private key mismatch so no need to call isCertKeyPairValid. + * + * @param cert A PEM formatted certificate + * @param pkey A PEM formatted key + */ + virtual void loadCertKeyPairFromBufferPEM( + folly::StringPiece cert, + folly::StringPiece pkey); + + /** + * Load cert and key from files. Guaranteed to throw if cert and key mismatch. + * Equivalent to calling loadCertificate() and loadPrivateKey(). + * + * @param certPath Path to the certificate file + * @param keyPath Path to the private key file + * @param certFormat Certificate file format + * @param keyFormat Private key file format + */ + virtual void loadCertKeyPairFromFiles( + const char* certPath, + const char* keyPath, + const char* certFormat = "PEM", + const char* keyFormat = "PEM"); + + /** + * Call after both cert and key are loaded to check if cert matches key. + * Must call if private key is loaded before loading the cert. + * No need to call if cert is loaded first before private key. + * @return true if matches, or false if mismatch. + */ + virtual bool isCertKeyPairValid() const; + /** * Load trusted certificates from specified file. * diff --git a/folly/io/async/test/SSLContextTest.cpp b/folly/io/async/test/SSLContextTest.cpp index 955a1fc8..fafa433b 100644 --- a/folly/io/async/test/SSLContextTest.cpp +++ b/folly/io/async/test/SSLContextTest.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -48,4 +49,92 @@ TEST_F(SSLContextTest, TestSetCipherList) { ctx.setCipherList(ciphers); verifySSLCipherList(ciphers); } + +TEST_F(SSLContextTest, TestLoadCertKey) { + std::string certData, keyData, anotherKeyData; + const char* certPath = "folly/io/async/test/certs/tests-cert.pem"; + const char* keyPath = "folly/io/async/test/certs/tests-key.pem"; + const char* anotherKeyPath = "folly/io/async/test/certs/client_key.pem"; + folly::readFile(certPath, certData); + folly::readFile(keyPath, keyData); + folly::readFile(anotherKeyPath, anotherKeyData); + + { + SCOPED_TRACE("Valid cert/key pair from buffer"); + SSLContext tmpCtx; + tmpCtx.loadCertificateFromBufferPEM(certData); + tmpCtx.loadPrivateKeyFromBufferPEM(keyData); + EXPECT_TRUE(tmpCtx.isCertKeyPairValid()); + } + + { + SCOPED_TRACE("Valid cert/key pair from files"); + SSLContext tmpCtx; + tmpCtx.loadCertificate(certPath); + tmpCtx.loadPrivateKey(keyPath); + EXPECT_TRUE(tmpCtx.isCertKeyPairValid()); + } + + { + SCOPED_TRACE("Invalid cert/key pair from file. Load cert first"); + SSLContext tmpCtx; + tmpCtx.loadCertificate(certPath); + EXPECT_THROW(tmpCtx.loadPrivateKey(anotherKeyPath), std::runtime_error); + } + + { + SCOPED_TRACE("Invalid cert/key pair from file. Load key first"); + SSLContext tmpCtx; + tmpCtx.loadPrivateKey(anotherKeyPath); + tmpCtx.loadCertificate(certPath); + EXPECT_FALSE(tmpCtx.isCertKeyPairValid()); + } + + { + SCOPED_TRACE("Invalid key/cert pair from buf. Load cert first"); + SSLContext tmpCtx; + tmpCtx.loadCertificateFromBufferPEM(certData); + EXPECT_THROW( + tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData), std::runtime_error); + } + + { + SCOPED_TRACE("Invalid key/cert pair from buf. Load key first"); + SSLContext tmpCtx; + tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData); + tmpCtx.loadCertificateFromBufferPEM(certData); + EXPECT_FALSE(tmpCtx.isCertKeyPairValid()); + } + + { + SCOPED_TRACE( + "loadCertKeyPairFromBufferPEM() must throw when cert/key mismatch"); + SSLContext tmpCtx; + EXPECT_THROW( + tmpCtx.loadCertKeyPairFromBufferPEM(certData, anotherKeyData), + std::runtime_error); + } + + { + SCOPED_TRACE( + "loadCertKeyPairFromBufferPEM() must succeed when cert/key match"); + SSLContext tmpCtx; + tmpCtx.loadCertKeyPairFromBufferPEM(certData, keyData); + } + + { + SCOPED_TRACE( + "loadCertKeyPairFromFiles() must throw when cert/key mismatch"); + SSLContext tmpCtx; + EXPECT_THROW( + tmpCtx.loadCertKeyPairFromFiles(certPath, anotherKeyPath), + std::runtime_error); + } + + { + SCOPED_TRACE("loadCertKeyPairFromFiles() must succeed when cert/key match"); + SSLContext tmpCtx; + tmpCtx.loadCertKeyPairFromFiles(certPath, keyPath); + } +} } // namespace folly -- 2.34.1