Simplify the StateSize helper in Random
authorYedidya Feldblum <yfeldblum@fb.com>
Mon, 31 Jul 2017 00:58:29 +0000 (17:58 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 31 Jul 2017 01:22:11 +0000 (18:22 -0700)
Summary:
[Folly] Simplify the `StateSize` helper in `Random`.

* Using member type aliases rather than class constants means we can remove definitions.
* Partially specializing over all RNG types with `state_size` class constants means we can remove the `mersenne_twister` specializations, which have many template parameters and are a pain.

Reviewed By: Orvid

Differential Revision: D5525144

fbshipit-source-id: bc27f112ed0d9b55befe9dabe08c4d345a402435

folly/Random-inl.h
folly/Random.h
folly/test/RandomTest.cpp

index 405ef1ffa189798b86566b3329d1bdadccab3139..f2f3b4ba9397b462fa6267f71ddb66e1a0db6b23 100644 (file)
@@ -18,8 +18,6 @@
 #error This file may only be included from folly/Random.h
 #endif
 
-#include <array>
-
 namespace folly {
 
 namespace detail {
@@ -29,82 +27,35 @@ namespace detail {
 // For some (mersenne_twister_engine), this is exported as a state_size static
 // data member; for others, the standard shows formulas.
 
-template <class RNG> struct StateSize {
+template <class RNG, typename = void>
+struct StateSize {
   // A sane default.
-  static constexpr size_t value = 512;
+  using type = std::integral_constant<size_t, 512>;
 };
 
 template <class RNG>
-constexpr size_t StateSize<RNG>::value;
+struct StateSize<RNG, void_t<decltype(RNG::state_size)>> {
+  using type = std::integral_constant<size_t, RNG::state_size>;
+};
 
 template <class UIntType, UIntType a, UIntType c, UIntType m>
 struct StateSize<std::linear_congruential_engine<UIntType, a, c, m>> {
   // From the standard [rand.eng.lcong], this is ceil(log2(m) / 32) + 3,
   // which is the same as ceil(ceil(log2(m) / 32) + 3, and
   // ceil(log2(m)) <= std::numeric_limits<UIntType>::digits
-  static constexpr size_t value =
-    (std::numeric_limits<UIntType>::digits + 31) / 32 + 3;
-};
-
-template <class UIntType, UIntType a, UIntType c, UIntType m>
-constexpr size_t
-StateSize<std::linear_congruential_engine<UIntType, a, c, m>>::value;
-
-template <class UIntType, size_t w, size_t n, size_t m, size_t r,
-          UIntType a, size_t u, UIntType d, size_t s,
-          UIntType b, size_t t,
-          UIntType c, size_t l, UIntType f>
-struct StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
-                                              a, u, d, s, b, t, c, l, f>> {
-  static constexpr size_t value =
-    std::mersenne_twister_engine<UIntType, w, n, m, r,
-                                 a, u, d, s, b, t, c, l, f>::state_size;
-};
-
-template <class UIntType, size_t w, size_t n, size_t m, size_t r,
-          UIntType a, size_t u, UIntType d, size_t s,
-          UIntType b, size_t t,
-          UIntType c, size_t l, UIntType f>
-constexpr size_t
-StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
-                                       a, u, d, s, b, t, c, l, f>>::value;
-
-#if FOLLY_HAVE_EXTRANDOM_SFMT19937
-
-template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
-          size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
-          uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
-          uint32_t parity4>
-struct StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
-    UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
-    parity1, parity2, parity3, parity4>> {
-  static constexpr size_t value =
-    __gnu_cxx::simd_fast_mersenne_twister_engine<
-        UIntType, m, pos1, sl1, sl2, sr1, sr2,
-        msk1, msk2, msk3, msk4,
-        parity1, parity2, parity3, parity4>::state_size;
+  using type = std::integral_constant<
+      size_t,
+      (std::numeric_limits<UIntType>::digits + 31) / 32 + 3>;
 };
 
-template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
-          size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
-          uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
-          uint32_t parity4>
-constexpr size_t
-StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
-    UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
-    parity1, parity2, parity3, parity4>>::value;
-
-#endif
-
 template <class UIntType, size_t w, size_t s, size_t r>
 struct StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>> {
   // [rand.eng.sub]: r * ceil(w / 32)
-  static constexpr size_t value = r * ((w + 31) / 32);
+  using type = std::integral_constant<size_t, r*((w + 31) / 32)>;
 };
 
-template <class UIntType, size_t w, size_t s, size_t r>
-constexpr size_t
-StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>>::value;
+template <typename RNG>
+using StateSizeT = _t<StateSize<RNG>>;
 
 template <class RNG>
 struct SeedData {
@@ -112,7 +63,7 @@ struct SeedData {
     Random::secureRandom(seedData.data(), seedData.size() * sizeof(uint32_t));
   }
 
-  static constexpr size_t stateSize = StateSize<RNG>::value;
+  static constexpr size_t stateSize = StateSizeT<RNG>::value;
   std::array<uint32_t, stateSize> seedData;
 };
 
index 0042176b68d474c2a73d8e9a20d18bf49ae105a9..dbb16bf5676a972eac3df937274620e12d974c82 100644 (file)
 #pragma once
 #define FOLLY_RANDOM_H_
 
+#include <array>
 #include <cstdint>
 #include <random>
 #include <type_traits>
 
 #include <folly/Portability.h>
+#include <folly/Traits.h>
 
 #if FOLLY_HAVE_EXTRANDOM_SFMT19937
 #include <ext/random>
index 30f992515d22dfa2e31c29e623ffdda17ec57b7e..d36275750033eec1dc9a64b96612b8c7ddb2a3f1 100644 (file)
@@ -32,13 +32,13 @@ TEST(Random, StateSize) {
   using namespace folly::detail;
 
   // uint_fast32_t is uint64_t on x86_64, w00t
-  EXPECT_EQ(sizeof(uint_fast32_t) / 4 + 3,
-            StateSize<std::minstd_rand0>::value);
-  EXPECT_EQ(624, StateSize<std::mt19937>::value);
+  EXPECT_EQ(
+      sizeof(uint_fast32_t) / 4 + 3, StateSizeT<std::minstd_rand0>::value);
+  EXPECT_EQ(624, StateSizeT<std::mt19937>::value);
 #if FOLLY_HAVE_EXTRANDOM_SFMT19937
-  EXPECT_EQ(624, StateSize<__gnu_cxx::sfmt19937>::value);
+  EXPECT_EQ(624, StateSizeT<__gnu_cxx::sfmt19937>::value);
 #endif
-  EXPECT_EQ(24, StateSize<std::ranlux24_base>::value);
+  EXPECT_EQ(24, StateSizeT<std::ranlux24_base>::value);
 }
 
 TEST(Random, Simple) {