Fix OpenSSLUtils to not include headers the wrong way
[folly.git] / folly / io / async / ssl / OpenSSLUtils.cpp
index 309327ed2de7f08da544f4e15d69d0dc6d1016f6..80d90ec04c92e409dbd618cf5618f0b4cf711bc0 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 #include <folly/io/async/ssl/OpenSSLUtils.h>
-#include <folly/ScopeGuard.h>
-#include <folly/portability/Sockets.h>
+
 #include <glog/logging.h>
-#include <openssl/bio.h>
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <openssl/x509v3.h>
+
 #include <unordered_map>
 
-#define OPENSSL_IS_101 (OPENSSL_VERSION_NUMBER >= 0x1000105fL && \
-                         OPENSSL_VERSION_NUMBER < 0x1000200fL)
-#define OPENSSL_IS_102 (OPENSSL_VERSION_NUMBER >= 0x1000200fL && \
-                        OPENSSL_VERSION_NUMBER < 0x10100000L)
-#define OPENSSL_IS_110 (OPENSSL_VERSION_NUMBER >= 0x10100000L)
+#include <folly/ScopeGuard.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
 namespace {
-#if defined(OPENSSL_IS_BORINGSSL)
+#ifdef OPENSSL_IS_BORINGSSL
 // BoringSSL doesn't (as of May 2016) export the equivalent
 // of BIO_sock_should_retry, so this is one way around it :(
 static int boringssl_bio_fd_should_retry(int err);
@@ -45,7 +39,7 @@ namespace ssl {
 bool OpenSSLUtils::getTLSMasterKey(
     const SSL_SESSION* session,
     MutableByteRange keyOut) {
-#if OPENSSL_IS_101 || OPENSSL_IS_102
+#if FOLLY_OPENSSL_IS_101 || FOLLY_OPENSSL_IS_102
   if (session &&
       session->master_key_length == static_cast<int>(keyOut.size())) {
     auto masterKey = session->master_key;
@@ -53,6 +47,9 @@ bool OpenSSLUtils::getTLSMasterKey(
         masterKey, masterKey + session->master_key_length, keyOut.begin());
     return true;
   }
+#else
+  (SSL_SESSION*)session;
+  (MutableByteRange) keyOut;
 #endif
   return false;
 }
@@ -60,13 +57,16 @@ bool OpenSSLUtils::getTLSMasterKey(
 bool OpenSSLUtils::getTLSClientRandom(
     const SSL* ssl,
     MutableByteRange randomOut) {
-#if OPENSSL_IS_101 || OPENSSL_IS_102
+#if FOLLY_OPENSSL_IS_101 || FOLLY_OPENSSL_IS_102
   if ((SSL_version(ssl) >> 8) == TLS1_VERSION_MAJOR && ssl->s3 &&
       randomOut.size() == SSL3_RANDOM_SIZE) {
     auto clientRandom = ssl->s3->client_random;
     std::copy(clientRandom, clientRandom + SSL3_RANDOM_SIZE, randomOut.begin());
     return true;
   }
+#else
+  (SSL*)ssl;
+  (MutableByteRange) randomOut;
 #endif
   return false;
 }
@@ -126,7 +126,7 @@ bool OpenSSLUtils::validatePeerCertNames(X509* cert,
     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
       // Extra const-ness for paranoia
       unsigned char const* const rawIpStr = name->d.iPAddress->data;
-      int const rawIpLen = name->d.iPAddress->length;
+      size_t const rawIpLen = size_t(name->d.iPAddress->length);
 
       if (rawIpLen == 4 && addr4 != nullptr) {
         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
@@ -169,10 +169,8 @@ static std::unordered_map<uint16_t, std::string> getOpenSSLCipherNames() {
   };
 
   STACK_OF(SSL_CIPHER)* sk = SSL_get_ciphers(ssl);
-  for (size_t i = 0; i < (size_t)sk_SSL_CIPHER_num(sk); i++) {
-    SSL_CIPHER* c;
-
-    c = sk_SSL_CIPHER_value(sk, i);
+  for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
+    const SSL_CIPHER* c = sk_SSL_CIPHER_value(sk, i);
     unsigned long id = SSL_CIPHER_get_id(c);
     // OpenSSL 1.0.2 and prior does weird things such as stuff the SSL/TLS
     // version into the top 16 bits. Let's ignore those for now. This is
@@ -197,17 +195,56 @@ const std::string& OpenSSLUtils::getCipherName(uint16_t cipherCode) {
   }
 }
 
+void OpenSSLUtils::setSSLInitialCtx(SSL* ssl, SSL_CTX* ctx) {
+  (void)ssl;
+  (void)ctx;
+#if !FOLLY_OPENSSL_IS_110 && !defined(OPENSSL_NO_TLSEXT)
+  if (ssl) {
+    ssl->initial_ctx = ctx;
+  }
+#endif
+}
+
+SSL_CTX* OpenSSLUtils::getSSLInitialCtx(SSL* ssl) {
+  (void)ssl;
+#if !FOLLY_OPENSSL_IS_110 && !defined(OPENSSL_NO_TLSEXT)
+  if (ssl) {
+    return ssl->initial_ctx;
+  }
+#endif
+  return nullptr;
+}
+
+BioMethodUniquePtr OpenSSLUtils::newSocketBioMethod() {
+  BIO_METHOD* newmeth = nullptr;
+#if FOLLY_OPENSSL_IS_110
+  if (!(newmeth = BIO_meth_new(BIO_TYPE_SOCKET, "socket_bio_method"))) {
+    return nullptr;
+  }
+  auto meth = const_cast<BIO_METHOD*>(BIO_s_socket());
+  BIO_meth_set_create(newmeth, BIO_meth_get_create(meth));
+  BIO_meth_set_destroy(newmeth, BIO_meth_get_destroy(meth));
+  BIO_meth_set_ctrl(newmeth, BIO_meth_get_ctrl(meth));
+  BIO_meth_set_callback_ctrl(newmeth, BIO_meth_get_callback_ctrl(meth));
+  BIO_meth_set_read(newmeth, BIO_meth_get_read(meth));
+  BIO_meth_set_write(newmeth, BIO_meth_get_write(meth));
+  BIO_meth_set_gets(newmeth, BIO_meth_get_gets(meth));
+  BIO_meth_set_puts(newmeth, BIO_meth_get_puts(meth));
+#else
+  if (!(newmeth = (BIO_METHOD*)OPENSSL_malloc(sizeof(BIO_METHOD)))) {
+    return nullptr;
+  }
+  memcpy(newmeth, BIO_s_socket(), sizeof(BIO_METHOD));
+#endif
+
+  return BioMethodUniquePtr(newmeth);
+}
+
 bool OpenSSLUtils::setCustomBioReadMethod(
     BIO_METHOD* bioMeth,
     int (*meth)(BIO*, char*, int)) {
   bool ret = false;
-#if OPENSSL_IS_110
   ret = (BIO_meth_set_read(bioMeth, meth) == 1);
-#elif (defined(OPENSSL_IS_BORINGSSL) || OPENSSL_IS_101 || OPENSSL_IS_102)
-  bioMeth->bread = meth;
-  ret = true;
-#endif
-
   return ret;
 }
 
@@ -215,19 +252,13 @@ bool OpenSSLUtils::setCustomBioWriteMethod(
     BIO_METHOD* bioMeth,
     int (*meth)(BIO*, const char*, int)) {
   bool ret = false;
-#if OPENSSL_IS_110
   ret = (BIO_meth_set_write(bioMeth, meth) == 1);
-#elif (defined(OPENSSL_IS_BORINGSSL) || OPENSSL_IS_101 || OPENSSL_IS_102)
-  bioMeth->bwrite = meth;
-  ret = true;
-#endif
-
   return ret;
 }
 
 int OpenSSLUtils::getBioShouldRetryWrite(int r) {
   int ret = 0;
-#if defined(OPENSSL_IS_BORINGSSL)
+#ifdef OPENSSL_IS_BORINGSSL
   ret = boringssl_bio_fd_should_retry(r);
 #else
   ret = BIO_sock_should_retry(r);
@@ -236,7 +267,7 @@ int OpenSSLUtils::getBioShouldRetryWrite(int r) {
 }
 
 void OpenSSLUtils::setBioAppData(BIO* b, void* ptr) {
-#if defined(OPENSSL_IS_BORINGSSL)
+#ifdef OPENSSL_IS_BORINGSSL
   BIO_set_callback_arg(b, static_cast<char*>(ptr));
 #else
   BIO_set_app_data(b, ptr);
@@ -244,21 +275,13 @@ void OpenSSLUtils::setBioAppData(BIO* b, void* ptr) {
 }
 
 void* OpenSSLUtils::getBioAppData(BIO* b) {
-#if defined(OPENSSL_IS_BORINGSSL)
+#ifdef OPENSSL_IS_BORINGSSL
   return BIO_get_callback_arg(b);
 #else
   return BIO_get_app_data(b);
 #endif
 }
 
-void OpenSSLUtils::setCustomBioMethod(BIO* b, BIO_METHOD* meth) {
-#if defined(OPENSSL_IS_BORINGSSL)
-  b->method = meth;
-#else
-  BIO_set(b, meth);
-#endif
-}
-
 int OpenSSLUtils::getBioFd(BIO* b, int* fd) {
 #ifdef _WIN32
   int ret = portability::sockets::socket_to_fd((SOCKET)BIO_get_fd(b, fd));
@@ -273,7 +296,11 @@ int OpenSSLUtils::getBioFd(BIO* b, int* fd) {
 
 void OpenSSLUtils::setBioFd(BIO* b, int fd, int flags) {
 #ifdef _WIN32
-  SOCKET sock = portability::sockets::fd_to_socket(fd);
+  SOCKET socket = portability::sockets::fd_to_socket(fd);
+  // Internally OpenSSL uses this as an int for reasons completely
+  // beyond any form of sanity, so we do the cast ourselves to avoid
+  // the warnings that would be generated.
+  int sock = int(socket);
 #else
   int sock = fd;
 #endif
@@ -284,7 +311,7 @@ void OpenSSLUtils::setBioFd(BIO* b, int fd, int flags) {
 } // folly
 
 namespace {
-#if defined(OPENSSL_IS_BORINGSSL)
+#ifdef OPENSSL_IS_BORINGSSL
 
 static int boringssl_bio_fd_non_fatal_error(int err) {
   if (
@@ -320,11 +347,6 @@ static int boringssl_bio_fd_non_fatal_error(int err) {
 
 #if defined(OPENSSL_WINDOWS)
 
-#include <io.h>
-#pragma warning(push, 3)
-#include <windows.h>
-#pragma warning(pop)
-
 int boringssl_bio_fd_should_retry(int i) {
   if (i == -1) {
     return boringssl_bio_fd_non_fatal_error((int)GetLastError());
@@ -334,7 +356,6 @@ int boringssl_bio_fd_should_retry(int i) {
 
 #else // !OPENSSL_WINDOWS
 
-#include <unistd.h>
 int boringssl_bio_fd_should_retry(int i) {
   if (i == -1) {
     return boringssl_bio_fd_non_fatal_error(errno);