f18a0b5118db89d1fbcd6b36da1b88601e680057
[folly.git] / folly / Random-inl.h
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef FOLLY_RANDOM_H_
18 #error This file may only be included from folly/Random.h
19 #endif
20
21 namespace folly {
22
23 namespace detail {
24
25 // Return the state size needed by RNG, expressed as a number of uint32_t
26 // integers. Specialized for all templates specified in the C++11 standard.
27 // For some (mersenne_twister_engine), this is exported as a state_size static
28 // data member; for others, the standard shows formulas.
29
30 template <class RNG> struct StateSize {
31   // A sane default.
32   static constexpr size_t value = 512;
33 };
34
35 template <class RNG>
36 constexpr size_t StateSize<RNG>::value;
37
38 template <class UIntType, UIntType a, UIntType c, UIntType m>
39 struct StateSize<std::linear_congruential_engine<UIntType, a, c, m>> {
40   // From the standard [rand.eng.lcong], this is ceil(log2(m) / 32) + 3,
41   // which is the same as ceil(ceil(log2(m) / 32) + 3, and
42   // ceil(log2(m)) <= std::numeric_limits<UIntType>::digits
43   static constexpr size_t value =
44     (std::numeric_limits<UIntType>::digits + 31) / 32 + 3;
45 };
46
47 template <class UIntType, UIntType a, UIntType c, UIntType m>
48 constexpr size_t
49 StateSize<std::linear_congruential_engine<UIntType, a, c, m>>::value;
50
51 template <class UIntType, size_t w, size_t n, size_t m, size_t r,
52           UIntType a, size_t u, UIntType d, size_t s,
53           UIntType b, size_t t,
54           UIntType c, size_t l, UIntType f>
55 struct StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
56                                               a, u, d, s, b, t, c, l, f>> {
57   static constexpr size_t value =
58     std::mersenne_twister_engine<UIntType, w, n, m, r,
59                                  a, u, d, s, b, t, c, l, f>::state_size;
60 };
61
62 template <class UIntType, size_t w, size_t n, size_t m, size_t r,
63           UIntType a, size_t u, UIntType d, size_t s,
64           UIntType b, size_t t,
65           UIntType c, size_t l, UIntType f>
66 constexpr size_t
67 StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
68                                        a, u, d, s, b, t, c, l, f>>::value;
69
70 #if FOLLY_USE_SIMD_PRNG
71
72 template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
73           size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
74           uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
75           uint32_t parity4>
76 struct StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
77     UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
78     parity1, parity2, parity3, parity4>> {
79   static constexpr size_t value =
80     __gnu_cxx::simd_fast_mersenne_twister_engine<
81         UIntType, m, pos1, sl1, sl2, sr1, sr2,
82         msk1, msk2, msk3, msk4,
83         parity1, parity2, parity3, parity4>::state_size;
84 };
85
86 template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
87           size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
88           uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
89           uint32_t parity4>
90 constexpr size_t
91 StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
92     UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
93     parity1, parity2, parity3, parity4>>::value;
94
95 #endif
96
97 template <class UIntType, size_t w, size_t s, size_t r>
98 struct StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>> {
99   // [rand.eng.sub]: r * ceil(w / 32)
100   static constexpr size_t value = r * ((w + 31) / 32);
101 };
102
103 template <class UIntType, size_t w, size_t s, size_t r>
104 constexpr size_t
105 StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>>::value;
106
107 template <class RNG>
108 struct SeedData {
109   SeedData() {
110     Random::secureRandom(seedData.begin(), seedData.size() * sizeof(uint32_t));
111   }
112
113   static constexpr size_t stateSize = StateSize<RNG>::value;
114   std::array<uint32_t, stateSize> seedData;
115 };
116
117 }  // namespace detail
118
119 template <class RNG>
120 void Random::seed(ValidRNG<RNG>& rng) {
121   detail::SeedData<RNG> sd;
122   std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
123   rng.seed(s);
124 }
125
126 template <class RNG>
127 auto Random::create() -> ValidRNG<RNG> {
128   detail::SeedData<RNG> sd;
129   std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
130   return RNG(s);
131 }
132
133 }  // namespaces