Fix copyright lines for Bits.h and move BitsBenchmark.cpp
[folly.git] / folly / ssl / OpenSSLHash.h
index 2a89e64b126cc17d0425497f356a08f832fd5812..0804c221858bf6442ee210a6c7d4b0ca7f995aa5 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #pragma once
 
-#include <openssl/evp.h>
-#include <openssl/hmac.h>
-#include <openssl/sha.h>
-
 #include <folly/Range.h>
 #include <folly/io/IOBuf.h>
+#include <folly/portability/OpenSSL.h>
+#include <folly/ssl/OpenSSLPtrTypes.h>
 
 namespace folly {
 namespace ssl {
@@ -30,21 +28,31 @@ namespace ssl {
 /// These functions are not thread-safe unless you initialize OpenSSL.
 class OpenSSLHash {
  public:
-
   class Digest {
    public:
-    Digest() {
-      EVP_MD_CTX_init(&ctx_);
+    Digest() : ctx_(EVP_MD_CTX_new()) {}
+
+    Digest(const Digest& other) {
+      ctx_ = EvpMdCtxUniquePtr(EVP_MD_CTX_new());
+      if (other.md_ != nullptr) {
+        hash_init(other.md_);
+        check_libssl_result(
+            1, EVP_MD_CTX_copy_ex(ctx_.get(), other.ctx_.get()));
+      }
     }
-    ~Digest() {
-      EVP_MD_CTX_cleanup(&ctx_);
+
+    Digest& operator=(const Digest& other) {
+      this->~Digest();
+      return *new (this) Digest(other);
     }
+
     void hash_init(const EVP_MD* md) {
       md_ = md;
-      check_libssl_result(1, EVP_DigestInit_ex(&ctx_, md, nullptr));
+      check_libssl_result(1, EVP_DigestInit_ex(ctx_.get(), md, nullptr));
     }
     void hash_update(ByteRange data) {
-      check_libssl_result(1, EVP_DigestUpdate(&ctx_, data.data(), data.size()));
+      check_libssl_result(
+          1, EVP_DigestUpdate(ctx_.get(), data.data(), data.size()));
     }
     void hash_update(const IOBuf& data) {
       for (auto r : data) {
@@ -53,30 +61,25 @@ class OpenSSLHash {
     }
     void hash_final(MutableByteRange out) {
       const auto size = EVP_MD_size(md_);
-      check_out_size(size, out);
+      check_out_size(size_t(size), out);
       unsigned int len = 0;
-      check_libssl_result(1, EVP_DigestFinal_ex(&ctx_, out.data(), &len));
-      check_libssl_result(size, len);
+      check_libssl_result(1, EVP_DigestFinal_ex(ctx_.get(), out.data(), &len));
+      check_libssl_result(size, int(len));
       md_ = nullptr;
     }
+
    private:
     const EVP_MD* md_ = nullptr;
-    EVP_MD_CTX ctx_;
+    EvpMdCtxUniquePtr ctx_{nullptr};
   };
 
-  static void hash(
-      MutableByteRange out,
-      const EVP_MD* md,
-      ByteRange data) {
+  static void hash(MutableByteRange out, const EVP_MD* md, ByteRange data) {
     Digest hash;
     hash.hash_init(md);
     hash.hash_update(data);
     hash.hash_final(out);
   }
-  static void hash(
-      MutableByteRange out,
-      const EVP_MD* md,
-      const IOBuf& data) {
+  static void hash(MutableByteRange out, const EVP_MD* md, const IOBuf& data) {
     Digest hash;
     hash.hash_init(md);
     hash.hash_update(data);
@@ -97,19 +100,16 @@ class OpenSSLHash {
 
   class Hmac {
    public:
-    Hmac() {
-      HMAC_CTX_init(&ctx_);
-    }
-    ~Hmac() {
-      HMAC_CTX_cleanup(&ctx_);
-    }
+    Hmac() : ctx_(HMAC_CTX_new()) {}
+
     void hash_init(const EVP_MD* md, ByteRange key) {
       md_ = md;
       check_libssl_result(
-          1, HMAC_Init_ex(&ctx_, key.data(), int(key.size()), md_, nullptr));
+          1,
+          HMAC_Init_ex(ctx_.get(), key.data(), int(key.size()), md_, nullptr));
     }
     void hash_update(ByteRange data) {
-      check_libssl_result(1, HMAC_Update(&ctx_, data.data(), data.size()));
+      check_libssl_result(1, HMAC_Update(ctx_.get(), data.data(), data.size()));
     }
     void hash_update(const IOBuf& data) {
       for (auto r : data) {
@@ -118,22 +118,20 @@ class OpenSSLHash {
     }
     void hash_final(MutableByteRange out) {
       const auto size = EVP_MD_size(md_);
-      check_out_size(size, out);
+      check_out_size(size_t(size), out);
       unsigned int len = 0;
-      check_libssl_result(1, HMAC_Final(&ctx_, out.data(), &len));
+      check_libssl_result(1, HMAC_Final(ctx_.get(), out.data(), &len));
       check_libssl_result(size, int(len));
       md_ = nullptr;
     }
+
    private:
     const EVP_MD* md_ = nullptr;
-    HMAC_CTX ctx_;
+    HmacCtxUniquePtr ctx_{nullptr};
   };
 
-  static void hmac(
-      MutableByteRange out,
-      const EVP_MD* md,
-      ByteRange key,
-      ByteRange data) {
+  static void
+  hmac(MutableByteRange out, const EVP_MD* md, ByteRange key, ByteRange data) {
     Hmac hmac;
     hmac.hash_init(md, key);
     hmac.hash_update(data);
@@ -149,20 +147,18 @@ class OpenSSLHash {
     hmac.hash_update(data);
     hmac.hash_final(out);
   }
-  static void hmac_sha1(
-      MutableByteRange out, ByteRange key, ByteRange data) {
+  static void hmac_sha1(MutableByteRange out, ByteRange key, ByteRange data) {
     hmac(out, EVP_sha1(), key, data);
   }
-  static void hmac_sha1(
-      MutableByteRange out, ByteRange key, const IOBuf& data) {
+  static void
+  hmac_sha1(MutableByteRange out, ByteRange key, const IOBuf& data) {
     hmac(out, EVP_sha1(), key, data);
   }
-  static void hmac_sha256(
-      MutableByteRange out, ByteRange key, ByteRange data) {
+  static void hmac_sha256(MutableByteRange out, ByteRange key, ByteRange data) {
     hmac(out, EVP_sha256(), key, data);
   }
-  static void hmac_sha256(
-      MutableByteRange out, ByteRange key, const IOBuf& data) {
+  static void
+  hmac_sha256(MutableByteRange out, ByteRange key, const IOBuf& data) {
     hmac(out, EVP_sha256(), key, data);
   }
 
@@ -185,6 +181,5 @@ class OpenSSLHash {
   }
   [[noreturn]] static void check_libssl_result_throw();
 };
-
-}
-}
+} // namespace ssl
+} // namespace folly