Add element construction/destruction hooks to IndexedMemPool
[folly.git] / folly / Random.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/Random.h>
18
19 #include <atomic>
20 #include <mutex>
21 #include <random>
22 #include <array>
23
24 #include <folly/CallOnce.h>
25 #include <folly/File.h>
26 #include <folly/FileUtil.h>
27 #include <folly/SingletonThreadLocal.h>
28 #include <folly/ThreadLocal.h>
29 #include <folly/portability/SysTime.h>
30 #include <folly/portability/Unistd.h>
31 #include <glog/logging.h>
32
33 #ifdef _MSC_VER
34 # include <wincrypt.h>
35 #endif
36
37 namespace folly {
38
39 namespace {
40
41 void readRandomDevice(void* data, size_t size) {
42 #ifdef _MSC_VER
43   static folly::once_flag flag;
44   static HCRYPTPROV cryptoProv;
45   folly::call_once(flag, [&] {
46     if (!CryptAcquireContext(
47             &cryptoProv,
48             nullptr,
49             nullptr,
50             PROV_RSA_FULL,
51             CRYPT_VERIFYCONTEXT)) {
52       if (GetLastError() == NTE_BAD_KEYSET) {
53         // Mostly likely cause of this is that no key container
54         // exists yet, so try to create one.
55         PCHECK(CryptAcquireContext(
56             &cryptoProv, nullptr, nullptr, PROV_RSA_FULL, CRYPT_NEWKEYSET));
57       } else {
58         LOG(FATAL) << "Failed to acquire the default crypto context.";
59       }
60     }
61   });
62   CHECK(size <= std::numeric_limits<DWORD>::max());
63   PCHECK(CryptGenRandom(cryptoProv, (DWORD)size, (BYTE*)data));
64 #else
65   // Keep the random device open for the duration of the program.
66   static int randomFd = ::open("/dev/urandom", O_RDONLY);
67   PCHECK(randomFd >= 0);
68   auto bytesRead = readFull(randomFd, data, size);
69   PCHECK(bytesRead >= 0 && size_t(bytesRead) == size);
70 #endif
71 }
72
73 class BufferedRandomDevice {
74  public:
75   static constexpr size_t kDefaultBufferSize = 128;
76
77   explicit BufferedRandomDevice(size_t bufferSize = kDefaultBufferSize);
78
79   void get(void* data, size_t size) {
80     if (LIKELY(size <= remaining())) {
81       memcpy(data, ptr_, size);
82       ptr_ += size;
83     } else {
84       getSlow(static_cast<unsigned char*>(data), size);
85     }
86   }
87
88  private:
89   void getSlow(unsigned char* data, size_t size);
90
91   inline size_t remaining() const {
92     return size_t(buffer_.get() + bufferSize_ - ptr_);
93   }
94
95   const size_t bufferSize_;
96   std::unique_ptr<unsigned char[]> buffer_;
97   unsigned char* ptr_;
98 };
99
100 BufferedRandomDevice::BufferedRandomDevice(size_t bufferSize)
101   : bufferSize_(bufferSize),
102     buffer_(new unsigned char[bufferSize]),
103     ptr_(buffer_.get() + bufferSize) {  // refill on first use
104 }
105
106 void BufferedRandomDevice::getSlow(unsigned char* data, size_t size) {
107   DCHECK_GT(size, remaining());
108   if (size >= bufferSize_) {
109     // Just read directly.
110     readRandomDevice(data, size);
111     return;
112   }
113
114   size_t copied = remaining();
115   memcpy(data, ptr_, copied);
116   data += copied;
117   size -= copied;
118
119   // refill
120   readRandomDevice(buffer_.get(), bufferSize_);
121   ptr_ = buffer_.get();
122
123   memcpy(data, ptr_, size);
124   ptr_ += size;
125 }
126
127 struct RandomTag {};
128
129 } // namespace
130
131 void Random::secureRandom(void* data, size_t size) {
132   static SingletonThreadLocal<BufferedRandomDevice, RandomTag>
133       bufferedRandomDevice;
134   bufferedRandomDevice.get().get(data, size);
135 }
136
137 class ThreadLocalPRNG::LocalInstancePRNG {
138  public:
139   LocalInstancePRNG() : rng(Random::create()) {}
140
141   Random::DefaultGenerator rng;
142 };
143
144 ThreadLocalPRNG::ThreadLocalPRNG() {
145   static SingletonThreadLocal<ThreadLocalPRNG::LocalInstancePRNG, RandomTag>
146       localInstancePRNG;
147   local_ = &localInstancePRNG.get();
148 }
149
150 uint32_t ThreadLocalPRNG::getImpl(LocalInstancePRNG* local) {
151   return local->rng();
152 }
153 }