65816b6d3fd6080f5daa96c9856781eb0cbd0bc7
[folly.git] / folly / concurrency / ConcurrentHashMap.h
1 /*
2  * Copyright 2017-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/Optional.h>
19 #include <folly/concurrency/detail/ConcurrentHashMap-detail.h>
20 #include <folly/experimental/hazptr/hazptr.h>
21 #include <atomic>
22 #include <mutex>
23
24 namespace folly {
25
26 /**
27  * Based on Java's ConcurrentHashMap
28  *
29  * Readers are always wait-free.
30  * Writers are sharded, but take a lock.
31  *
32  * The interface is as close to std::unordered_map as possible, but there
33  * are a handful of changes:
34  *
35  * * Iterators hold hazard pointers to the returned elements.  Elements can only
36  *   be accessed while Iterators are still valid!
37  *
38  * * Therefore operator[] and at() return copies, since they do not
39  *   return an iterator.  The returned value is const, to remind you
40  *   that changes do not affect the value in the map.
41  *
42  * * erase() calls the hash function, and may fail if the hash
43  *   function throws an exception.
44  *
45  * * clear() initializes new segments, and is not noexcept.
46  *
47  * * The interface adds assign_if_equal, since find() doesn't take a lock.
48  *
49  * * Only const version of find() is supported, and const iterators.
50  *   Mutation must use functions provided, like assign().
51  *
52  * * iteration iterates over all the buckets in the table, unlike
53  *   std::unordered_map which iterates over a linked list of elements.
54  *   If the table is sparse, this may be more expensive.
55  *
56  * * rehash policy is a power of two, using supplied factor.
57  *
58  * * Allocator must be stateless.
59  *
60  * * ValueTypes without copy constructors will work, but pessimize the
61  *   implementation.
62  *
63  * Comparisons:
64  *      Single-threaded performance is extremely similar to std::unordered_map.
65  *
66  *      Multithreaded performance beats anything except the lock-free
67  *           atomic maps (AtomicUnorderedMap, AtomicHashMap), BUT only
68  *           if you can perfectly size the atomic maps, and you don't
69  *           need erase().  If you don't know the size in advance or
70  *           your workload frequently calls erase(), this is the
71  *           better choice.
72  */
73
74 template <
75     typename KeyType,
76     typename ValueType,
77     typename HashFn = std::hash<KeyType>,
78     typename KeyEqual = std::equal_to<KeyType>,
79     typename Allocator = std::allocator<uint8_t>,
80     uint8_t ShardBits = 8,
81     template <typename> class Atom = std::atomic,
82     class Mutex = std::mutex>
83 class ConcurrentHashMap {
84   using SegmentT = detail::ConcurrentHashMapSegment<
85       KeyType,
86       ValueType,
87       ShardBits,
88       HashFn,
89       KeyEqual,
90       Allocator,
91       Atom,
92       Mutex>;
93   static constexpr uint64_t NumShards = (1 << ShardBits);
94   // Slightly higher than 1.0, in case hashing to shards isn't
95   // perfectly balanced, reserve(size) will still work without
96   // rehashing.
97   float load_factor_ = 1.05;
98
99  public:
100   class ConstIterator;
101
102   typedef KeyType key_type;
103   typedef ValueType mapped_type;
104   typedef std::pair<const KeyType, ValueType> value_type;
105   typedef std::size_t size_type;
106   typedef HashFn hasher;
107   typedef KeyEqual key_equal;
108   typedef ConstIterator const_iterator;
109
110   /*
111    * Construct a ConcurrentHashMap with 1 << ShardBits shards, size
112    * and max_size given.  Both size and max_size will be rounded up to
113    * the next power of two, if they are not already a power of two, so
114    * that we can index in to Shards efficiently.
115    *
116    * Insertion functions will throw bad_alloc if max_size is exceeded.
117    */
118   explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) {
119     size_ = folly::nextPowTwo(size);
120     if (max_size != 0) {
121       max_size_ = folly::nextPowTwo(max_size);
122     }
123     CHECK(max_size_ == 0 || max_size_ >= size_);
124     for (uint64_t i = 0; i < NumShards; i++) {
125       segments_[i].store(nullptr, std::memory_order_relaxed);
126     }
127   }
128
129   ConcurrentHashMap(ConcurrentHashMap&& o) noexcept {
130     for (uint64_t i = 0; i < NumShards; i++) {
131       segments_[i].store(
132           o.segments_[i].load(std::memory_order_relaxed),
133           std::memory_order_relaxed);
134       o.segments_[i].store(nullptr, std::memory_order_relaxed);
135     }
136   }
137
138   ConcurrentHashMap& operator=(ConcurrentHashMap&& o) {
139     for (uint64_t i = 0; i < NumShards; i++) {
140       auto seg = segments_[i].load(std::memory_order_relaxed);
141       if (seg) {
142         seg->~SegmentT();
143         Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
144       }
145       segments_[i].store(
146           o.segments_[i].load(std::memory_order_relaxed),
147           std::memory_order_relaxed);
148       o.segments_[i].store(nullptr, std::memory_order_relaxed);
149     }
150     return *this;
151   }
152
153   ~ConcurrentHashMap() {
154     for (uint64_t i = 0; i < NumShards; i++) {
155       auto seg = segments_[i].load(std::memory_order_relaxed);
156       if (seg) {
157         seg->~SegmentT();
158         Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
159       }
160     }
161   }
162
163   bool empty() const noexcept {
164     for (uint64_t i = 0; i < NumShards; i++) {
165       auto seg = segments_[i].load(std::memory_order_acquire);
166       if (seg) {
167         if (!seg->empty()) {
168           return false;
169         }
170       }
171     }
172     return true;
173   }
174
175   ConstIterator find(const KeyType& k) const {
176     auto segment = pickSegment(k);
177     ConstIterator res(this, segment);
178     auto seg = segments_[segment].load(std::memory_order_acquire);
179     if (!seg || !seg->find(res.it_, k)) {
180       res.segment_ = NumShards;
181     }
182     return res;
183   }
184
185   ConstIterator cend() const noexcept {
186     return ConstIterator(NumShards);
187   }
188
189   ConstIterator cbegin() const noexcept {
190     return ConstIterator(this);
191   }
192
193   std::pair<ConstIterator, bool> insert(
194       std::pair<key_type, mapped_type>&& foo) {
195     auto segment = pickSegment(foo.first);
196     std::pair<ConstIterator, bool> res(
197         std::piecewise_construct,
198         std::forward_as_tuple(this, segment),
199         std::forward_as_tuple(false));
200     res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
201     return res;
202   }
203
204   std::pair<ConstIterator, bool> insert(const KeyType& k, const ValueType& v) {
205     auto segment = pickSegment(k);
206     std::pair<ConstIterator, bool> res(
207         std::piecewise_construct,
208         std::forward_as_tuple(this, segment),
209         std::forward_as_tuple(false));
210     res.second = ensureSegment(segment)->insert(res.first.it_, k, v);
211     return res;
212   }
213
214   template <typename... Args>
215   std::pair<ConstIterator, bool> try_emplace(const KeyType& k, Args&&... args) {
216     auto segment = pickSegment(k);
217     std::pair<ConstIterator, bool> res(
218         std::piecewise_construct,
219         std::forward_as_tuple(this, segment),
220         std::forward_as_tuple(false));
221     res.second = ensureSegment(segment)->try_emplace(
222         res.first.it_, k, std::forward<Args>(args)...);
223     return res;
224   }
225
226   template <typename... Args>
227   std::pair<ConstIterator, bool> emplace(Args&&... args) {
228     using Node = typename SegmentT::Node;
229     auto node = (Node*)Allocator().allocate(sizeof(Node));
230     new (node) Node(std::forward<Args>(args)...);
231     auto segment = pickSegment(node->getItem().first);
232     std::pair<ConstIterator, bool> res(
233         std::piecewise_construct,
234         std::forward_as_tuple(this, segment),
235         std::forward_as_tuple(false));
236     res.second = ensureSegment(segment)->emplace(
237         res.first.it_, node->getItem().first, node);
238     if (!res.second) {
239       node->~Node();
240       Allocator().deallocate((uint8_t*)node, sizeof(Node));
241     }
242     return res;
243   }
244
245   std::pair<ConstIterator, bool> insert_or_assign(
246       const KeyType& k,
247       const ValueType& v) {
248     auto segment = pickSegment(k);
249     std::pair<ConstIterator, bool> res(
250         std::piecewise_construct,
251         std::forward_as_tuple(this, segment),
252         std::forward_as_tuple(false));
253     res.second = ensureSegment(segment)->insert_or_assign(res.first.it_, k, v);
254     return res;
255   }
256
257   folly::Optional<ConstIterator> assign(const KeyType& k, const ValueType& v) {
258     auto segment = pickSegment(k);
259     ConstIterator res(this, segment);
260     auto seg = segments_[segment].load(std::memory_order_acquire);
261     if (!seg) {
262       return folly::Optional<ConstIterator>();
263     } else {
264       auto r = seg->assign(res.it_, k, v);
265       if (!r) {
266         return folly::Optional<ConstIterator>();
267       }
268     }
269     return res;
270   }
271
272   // Assign to desired if and only if key k is equal to expected
273   folly::Optional<ConstIterator> assign_if_equal(
274       const KeyType& k,
275       const ValueType& expected,
276       const ValueType& desired) {
277     auto segment = pickSegment(k);
278     ConstIterator res(this, segment);
279     auto seg = segments_[segment].load(std::memory_order_acquire);
280     if (!seg) {
281       return folly::Optional<ConstIterator>();
282     } else {
283       auto r = seg->assign_if_equal(res.it_, k, expected, desired);
284       if (!r) {
285         return folly::Optional<ConstIterator>();
286       }
287     }
288     return res;
289   }
290
291   // Copying wrappers around insert and find.
292   // Only available for copyable types.
293   const ValueType operator[](const KeyType& key) {
294     auto item = insert(key, ValueType());
295     return item.first->second;
296   }
297
298   const ValueType at(const KeyType& key) const {
299     auto item = find(key);
300     if (item == cend()) {
301       throw std::out_of_range("at(): value out of range");
302     }
303     return item->second;
304   }
305
306   // TODO update assign interface, operator[], at
307
308   size_type erase(const key_type& k) {
309     auto segment = pickSegment(k);
310     auto seg = segments_[segment].load(std::memory_order_acquire);
311     if (!seg) {
312       return 0;
313     } else {
314       return seg->erase(k);
315     }
316   }
317
318   // Calls the hash function, and therefore may throw.
319   ConstIterator erase(ConstIterator& pos) {
320     auto segment = pickSegment(pos->first);
321     ConstIterator res(this, segment);
322     res.next();
323     ensureSegment(segment)->erase(res.it_, pos.it_);
324     res.next(); // May point to segment end, and need to advance.
325     return res;
326   }
327
328   // NOT noexcept, initializes new shard segments vs.
329   void clear() {
330     for (uint64_t i = 0; i < NumShards; i++) {
331       auto seg = segments_[i].load(std::memory_order_acquire);
332       if (seg) {
333         seg->clear();
334       }
335     }
336   }
337
338   void reserve(size_t count) {
339     count = count >> ShardBits;
340     for (uint64_t i = 0; i < NumShards; i++) {
341       auto seg = segments_[i].load(std::memory_order_acquire);
342       if (seg) {
343         seg->rehash(count);
344       }
345     }
346   }
347
348   // This is a rolling size, and is not exact at any moment in time.
349   size_t size() const noexcept {
350     size_t res = 0;
351     for (uint64_t i = 0; i < NumShards; i++) {
352       auto seg = segments_[i].load(std::memory_order_acquire);
353       if (seg) {
354         res += seg->size();
355       }
356     }
357     return res;
358   }
359
360   float max_load_factor() const {
361     return load_factor_;
362   }
363
364   void max_load_factor(float factor) {
365     for (uint64_t i = 0; i < NumShards; i++) {
366       auto seg = segments_[i].load(std::memory_order_acquire);
367       if (seg) {
368         seg->max_load_factor(factor);
369       }
370     }
371   }
372
373   class ConstIterator {
374    public:
375     friend class ConcurrentHashMap;
376
377     const value_type& operator*() const {
378       return *it_;
379     }
380
381     const value_type* operator->() const {
382       return &*it_;
383     }
384
385     ConstIterator& operator++() {
386       it_++;
387       next();
388       return *this;
389     }
390
391     ConstIterator operator++(int) {
392       auto prev = *this;
393       ++*this;
394       return prev;
395     }
396
397     bool operator==(const ConstIterator& o) const {
398       return it_ == o.it_ && segment_ == o.segment_;
399     }
400
401     bool operator!=(const ConstIterator& o) const {
402       return !(*this == o);
403     }
404
405     ConstIterator& operator=(const ConstIterator& o) {
406       it_ = o.it_;
407       segment_ = o.segment_;
408       return *this;
409     }
410
411     ConstIterator(const ConstIterator& o) {
412       it_ = o.it_;
413       segment_ = o.segment_;
414     }
415
416     ConstIterator(const ConcurrentHashMap* parent, uint64_t segment)
417         : segment_(segment), parent_(parent) {}
418
419    private:
420     // cbegin iterator
421     explicit ConstIterator(const ConcurrentHashMap* parent)
422         : it_(parent->ensureSegment(0)->cbegin()),
423           segment_(0),
424           parent_(parent) {
425       // Always iterate to the first element, could be in any shard.
426       next();
427     }
428
429     // cend iterator
430     explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {}
431
432     void next() {
433       while (it_ == parent_->ensureSegment(segment_)->cend() &&
434              segment_ < parent_->NumShards) {
435         segment_++;
436         auto seg = parent_->segments_[segment_].load(std::memory_order_acquire);
437         if (segment_ < parent_->NumShards) {
438           if (!seg) {
439             continue;
440           }
441           it_ = seg->cbegin();
442         }
443       }
444     }
445
446     typename SegmentT::Iterator it_;
447     uint64_t segment_;
448     const ConcurrentHashMap* parent_;
449   };
450
451  private:
452   uint64_t pickSegment(const KeyType& k) const {
453     auto h = HashFn()(k);
454     // Use the lowest bits for our shard bits.
455     //
456     // This works well even if the hash function is biased towards the
457     // low bits: The sharding will happen in the segments_ instead of
458     // in the segment buckets, so we'll still get write sharding as
459     // well.
460     //
461     // Low-bit bias happens often for std::hash using small numbers,
462     // since the integer hash function is the identity function.
463     return h & (NumShards - 1);
464   }
465
466   SegmentT* ensureSegment(uint64_t i) const {
467     SegmentT* seg = segments_[i].load(std::memory_order_acquire);
468     if (!seg) {
469       SegmentT* newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT));
470       newseg = new (newseg)
471           SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits);
472       if (!segments_[i].compare_exchange_strong(seg, newseg)) {
473         // seg is updated with new value, delete ours.
474         newseg->~SegmentT();
475         Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT));
476       } else {
477         seg = newseg;
478       }
479     }
480     return seg;
481   }
482
483   mutable Atom<SegmentT*> segments_[NumShards];
484   size_t size_{0};
485   size_t max_size_{0};
486 };
487
488 } // namespace