Thread-safe version of loopKeepAlive()
[folly.git] / folly / Random.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/Random.h>
18
19 #include <atomic>
20 #include <mutex>
21 #include <random>
22 #include <array>
23
24 #include <glog/logging.h>
25 #include <folly/CallOnce.h>
26 #include <folly/File.h>
27 #include <folly/FileUtil.h>
28 #include <folly/ThreadLocal.h>
29 #include <folly/portability/SysTime.h>
30 #include <folly/portability/Unistd.h>
31
32 #ifdef _MSC_VER
33 # include <wincrypt.h>
34 #endif
35
36 namespace folly {
37
38 namespace {
39
40 void readRandomDevice(void* data, size_t size) {
41 #ifdef _MSC_VER
42   static folly::once_flag flag;
43   static HCRYPTPROV cryptoProv;
44   folly::call_once(flag, [&] {
45     if (!CryptAcquireContext(
46             &cryptoProv,
47             nullptr,
48             nullptr,
49             PROV_RSA_FULL,
50             CRYPT_VERIFYCONTEXT)) {
51       if (GetLastError() == NTE_BAD_KEYSET) {
52         // Mostly likely cause of this is that no key container
53         // exists yet, so try to create one.
54         PCHECK(CryptAcquireContext(
55             &cryptoProv, nullptr, nullptr, PROV_RSA_FULL, CRYPT_NEWKEYSET));
56       } else {
57         LOG(FATAL) << "Failed to acquire the default crypto context.";
58       }
59     }
60   });
61   CHECK(size <= std::numeric_limits<DWORD>::max());
62   PCHECK(CryptGenRandom(cryptoProv, (DWORD)size, (BYTE*)data));
63 #else
64   // Keep the random device open for the duration of the program.
65   static int randomFd = ::open("/dev/urandom", O_RDONLY);
66   PCHECK(randomFd >= 0);
67   auto bytesRead = readFull(randomFd, data, size);
68   PCHECK(bytesRead >= 0 && size_t(bytesRead) == size);
69 #endif
70 }
71
72 class BufferedRandomDevice {
73  public:
74   static constexpr size_t kDefaultBufferSize = 128;
75
76   explicit BufferedRandomDevice(size_t bufferSize = kDefaultBufferSize);
77
78   void get(void* data, size_t size) {
79     if (LIKELY(size <= remaining())) {
80       memcpy(data, ptr_, size);
81       ptr_ += size;
82     } else {
83       getSlow(static_cast<unsigned char*>(data), size);
84     }
85   }
86
87  private:
88   void getSlow(unsigned char* data, size_t size);
89
90   inline size_t remaining() const {
91     return buffer_.get() + bufferSize_ - ptr_;
92   }
93
94   const size_t bufferSize_;
95   std::unique_ptr<unsigned char[]> buffer_;
96   unsigned char* ptr_;
97 };
98
99 BufferedRandomDevice::BufferedRandomDevice(size_t bufferSize)
100   : bufferSize_(bufferSize),
101     buffer_(new unsigned char[bufferSize]),
102     ptr_(buffer_.get() + bufferSize) {  // refill on first use
103 }
104
105 void BufferedRandomDevice::getSlow(unsigned char* data, size_t size) {
106   DCHECK_GT(size, remaining());
107   if (size >= bufferSize_) {
108     // Just read directly.
109     readRandomDevice(data, size);
110     return;
111   }
112
113   size_t copied = remaining();
114   memcpy(data, ptr_, copied);
115   data += copied;
116   size -= copied;
117
118   // refill
119   readRandomDevice(buffer_.get(), bufferSize_);
120   ptr_ = buffer_.get();
121
122   memcpy(data, ptr_, size);
123   ptr_ += size;
124 }
125
126
127 }  // namespace
128
129 void Random::secureRandom(void* data, size_t size) {
130   static ThreadLocal<BufferedRandomDevice> bufferedRandomDevice;
131   bufferedRandomDevice->get(data, size);
132 }
133
134 class ThreadLocalPRNG::LocalInstancePRNG {
135  public:
136   LocalInstancePRNG() : rng(Random::create()) { }
137
138   Random::DefaultGenerator rng;
139 };
140
141 ThreadLocalPRNG::ThreadLocalPRNG() {
142   static folly::ThreadLocal<ThreadLocalPRNG::LocalInstancePRNG> localInstance;
143   local_ = localInstance.get();
144 }
145
146 uint32_t ThreadLocalPRNG::getImpl(LocalInstancePRNG* local) {
147   return local->rng();
148 }
149
150 }