ec1c8df3a8aa5a84466a1df18fa3e7378d7ca22c
[folly.git] / folly / concurrency / detail / ConcurrentHashMap-detail.h
1 /*
2  * Copyright 2017-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/experimental/hazptr/hazptr.h>
19 #include <atomic>
20 #include <mutex>
21
22 namespace folly {
23
24 namespace detail {
25
26 namespace concurrenthashmap {
27
28 // hazptr retire() that can use an allocator.
29 template <typename Allocator>
30 class HazptrDeleter {
31  public:
32   template <typename Node>
33   void operator()(Node* node) {
34     node->~Node();
35     Allocator().deallocate((uint8_t*)node, sizeof(Node));
36   }
37 };
38
39 template <
40     typename KeyType,
41     typename ValueType,
42     typename Allocator,
43     typename Enabled = void>
44 class ValueHolder {
45  public:
46   typedef std::pair<const KeyType, ValueType> value_type;
47
48   explicit ValueHolder(const ValueHolder& other) : item_(other.item_) {}
49
50   template <typename... Args>
51   ValueHolder(const KeyType& k, Args&&... args)
52       : item_(
53             std::piecewise_construct,
54             std::forward_as_tuple(k),
55             std::forward_as_tuple(std::forward<Args>(args)...)) {}
56   value_type& getItem() {
57     return item_;
58   }
59
60  private:
61   value_type item_;
62 };
63
64 // If the ValueType is not copy constructible, we can instead add
65 // an extra indirection.  Adds more allocations / deallocations and
66 // pulls in an extra cacheline.
67 template <typename KeyType, typename ValueType, typename Allocator>
68 class ValueHolder<
69     KeyType,
70     ValueType,
71     Allocator,
72     std::enable_if_t<!std::is_nothrow_copy_constructible<ValueType>::value>> {
73  public:
74   typedef std::pair<const KeyType, ValueType> value_type;
75
76   explicit ValueHolder(const ValueHolder& other) {
77     other.owned_ = false;
78     item_ = other.item_;
79   }
80
81   template <typename... Args>
82   ValueHolder(const KeyType& k, Args&&... args) {
83     item_ = (value_type*)Allocator().allocate(sizeof(value_type));
84     new (item_) value_type(
85         std::piecewise_construct,
86         std::forward_as_tuple(k),
87         std::forward_as_tuple(std::forward<Args>(args)...));
88   }
89
90   ~ValueHolder() {
91     if (owned_) {
92       item_->~value_type();
93       Allocator().deallocate((uint8_t*)item_, sizeof(value_type));
94     }
95   }
96
97   value_type& getItem() {
98     return *item_;
99   }
100
101  private:
102   value_type* item_;
103   mutable bool owned_{true};
104 };
105
106 template <
107     typename KeyType,
108     typename ValueType,
109     typename Allocator,
110     template <typename> class Atom = std::atomic>
111 class NodeT : public folly::hazptr::hazptr_obj_base<
112                   NodeT<KeyType, ValueType, Allocator, Atom>,
113                   concurrenthashmap::HazptrDeleter<Allocator>> {
114  public:
115   typedef std::pair<const KeyType, ValueType> value_type;
116
117   explicit NodeT(NodeT* other) : item_(other->item_) {}
118
119   template <typename... Args>
120   NodeT(const KeyType& k, Args&&... args)
121       : item_(k, std::forward<Args>(args)...) {}
122
123   /* Nodes are refcounted: If a node is retired() while a writer is
124      traversing the chain, the rest of the chain must remain valid
125      until all readers are finished.  This includes the shared tail
126      portion of the chain, as well as both old/new hash buckets that
127      may point to the same portion, and erased nodes may increase the
128      refcount */
129   void acquire() {
130     DCHECK(refcount_.load() != 0);
131     refcount_.fetch_add(1);
132   }
133   void release() {
134     if (refcount_.fetch_sub(1) == 1 /* was previously 1 */) {
135       this->retire(
136           folly::hazptr::default_hazptr_domain(),
137           concurrenthashmap::HazptrDeleter<Allocator>());
138     }
139   }
140   ~NodeT() {
141     auto next = next_.load(std::memory_order_acquire);
142     if (next) {
143       next->release();
144     }
145   }
146
147   value_type& getItem() {
148     return item_.getItem();
149   }
150   Atom<NodeT*> next_{nullptr};
151
152  private:
153   ValueHolder<KeyType, ValueType, Allocator> item_;
154   Atom<uint8_t> refcount_{1};
155 };
156
157 } // namespace concurrenthashmap
158
159 /* A Segment is a single shard of the ConcurrentHashMap.
160  * All writes take the lock, while readers are all wait-free.
161  * Readers always proceed in parallel with the single writer.
162  *
163  *
164  * Possible additional optimizations:
165  *
166  * * insert / erase could be lock / wait free.  Would need to be
167  *   careful that assign and rehash don't conflict (possibly with
168  *   reader/writer lock, or microlock per node or per bucket, etc).
169  *   Java 8 goes halfway, and and does lock per bucket, except for the
170  *   first item, that is inserted with a CAS (which is somewhat
171  *   specific to java having a lock per object)
172  *
173  * * I tried using trylock() and find() to warm the cache for insert()
174  *   and erase() similar to Java 7, but didn't have much luck.
175  *
176  * * We could order elements using split ordering, for faster rehash,
177  *   and no need to ever copy nodes.  Note that a full split ordering
178  *   including dummy nodes increases the memory usage by 2x, but we
179  *   could split the difference and still require a lock to set bucket
180  *   pointers.
181  *
182  * * hazptr acquire/release could be optimized more, in
183  *   single-threaded case, hazptr overhead is ~30% for a hot find()
184  *   loop.
185  */
186 template <
187     typename KeyType,
188     typename ValueType,
189     uint8_t ShardBits = 0,
190     typename HashFn = std::hash<KeyType>,
191     typename KeyEqual = std::equal_to<KeyType>,
192     typename Allocator = std::allocator<uint8_t>,
193     template <typename> class Atom = std::atomic,
194     class Mutex = std::mutex>
195 class FOLLY_ALIGNED(64) ConcurrentHashMapSegment {
196   enum class InsertType {
197     DOES_NOT_EXIST, // insert/emplace operations.  If key exists, return false.
198     MUST_EXIST, // assign operations.  If key does not exist, return false.
199     ANY, // insert_or_assign.
200     MATCH, // assign_if_equal (not in std).  For concurrent maps, a
201            // way to atomically change a value if equal to some other
202            // value.
203   };
204
205  public:
206   typedef KeyType key_type;
207   typedef ValueType mapped_type;
208   typedef std::pair<const KeyType, ValueType> value_type;
209   typedef std::size_t size_type;
210
211   using Node = concurrenthashmap::NodeT<KeyType, ValueType, Allocator, Atom>;
212   class Iterator;
213
214   ConcurrentHashMapSegment(
215       size_t initial_buckets,
216       float load_factor,
217       size_t max_size)
218       : load_factor_(load_factor) {
219     auto buckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
220     initial_buckets = folly::nextPowTwo(initial_buckets);
221     if (max_size != 0) {
222       max_size_ = folly::nextPowTwo(max_size);
223     }
224     if (max_size_ > max_size) {
225       max_size_ >> 1;
226     }
227
228     CHECK(max_size_ == 0 || (folly::popcount(max_size_ - 1) + ShardBits <= 32));
229     new (buckets) Buckets(initial_buckets);
230     buckets_.store(buckets, std::memory_order_release);
231     load_factor_nodes_ = initial_buckets * load_factor_;
232   }
233
234   ~ConcurrentHashMapSegment() {
235     auto buckets = buckets_.load(std::memory_order_relaxed);
236     // We can delete and not retire() here, since users must have
237     // their own synchronization around destruction.
238     buckets->~Buckets();
239     Allocator().deallocate((uint8_t*)buckets, sizeof(Buckets));
240   }
241
242   size_t size() {
243     return size_;
244   }
245
246   bool empty() {
247     return size() == 0;
248   }
249
250   bool insert(Iterator& it, std::pair<key_type, mapped_type>&& foo) {
251     return insert(it, foo.first, foo.second);
252   }
253
254   bool insert(Iterator& it, const KeyType& k, const ValueType& v) {
255     auto node = (Node*)Allocator().allocate(sizeof(Node));
256     new (node) Node(k, v);
257     auto res = insert_internal(
258         it,
259         k,
260         InsertType::DOES_NOT_EXIST,
261         [](const ValueType&) { return false; },
262         node,
263         v);
264     if (!res) {
265       node->~Node();
266       Allocator().deallocate((uint8_t*)node, sizeof(Node));
267     }
268     return res;
269   }
270
271   template <typename... Args>
272   bool try_emplace(Iterator& it, const KeyType& k, Args&&... args) {
273     return insert_internal(
274         it,
275         k,
276         InsertType::DOES_NOT_EXIST,
277         [](const ValueType&) { return false; },
278         nullptr,
279         std::forward<Args>(args)...);
280   }
281
282   template <typename... Args>
283   bool emplace(Iterator& it, const KeyType& k, Node* node) {
284     return insert_internal(
285         it,
286         k,
287         InsertType::DOES_NOT_EXIST,
288         [](const ValueType&) { return false; },
289         node);
290   }
291
292   bool insert_or_assign(Iterator& it, const KeyType& k, const ValueType& v) {
293     return insert_internal(
294         it,
295         k,
296         InsertType::ANY,
297         [](const ValueType&) { return false; },
298         nullptr,
299         v);
300   }
301
302   bool assign(Iterator& it, const KeyType& k, const ValueType& v) {
303     auto node = (Node*)Allocator().allocate(sizeof(Node));
304     new (node) Node(k, v);
305     auto res = insert_internal(
306         it,
307         k,
308         InsertType::MUST_EXIST,
309         [](const ValueType&) { return false; },
310         node,
311         v);
312     if (!res) {
313       node->~Node();
314       Allocator().deallocate((uint8_t*)node, sizeof(Node));
315     }
316     return res;
317   }
318
319   bool assign_if_equal(
320       Iterator& it,
321       const KeyType& k,
322       const ValueType& expected,
323       const ValueType& desired) {
324     return insert_internal(
325         it,
326         k,
327         InsertType::MATCH,
328         [expected](const ValueType& v) { return v == expected; },
329         nullptr,
330         desired);
331   }
332
333   template <typename MatchFunc, typename... Args>
334   bool insert_internal(
335       Iterator& it,
336       const KeyType& k,
337       InsertType type,
338       MatchFunc match,
339       Node* cur,
340       Args&&... args) {
341     auto h = HashFn()(k);
342     std::unique_lock<Mutex> g(m_);
343
344     auto buckets = buckets_.load(std::memory_order_relaxed);
345     // Check for rehash needed for DOES_NOT_EXIST
346     if (size_ >= load_factor_nodes_ && type == InsertType::DOES_NOT_EXIST) {
347       if (max_size_ && size_ << 1 > max_size_) {
348         // Would exceed max size.
349         throw std::bad_alloc();
350       }
351       rehash(buckets->bucket_count_ << 1);
352       buckets = buckets_.load(std::memory_order_relaxed);
353     }
354
355     auto idx = getIdx(buckets, h);
356     auto head = &buckets->buckets_[idx];
357     auto node = head->load(std::memory_order_relaxed);
358     auto headnode = node;
359     auto prev = head;
360     it.buckets_hazptr_.reset(buckets);
361     while (node) {
362       // Is the key found?
363       if (KeyEqual()(k, node->getItem().first)) {
364         it.setNode(node, buckets, idx);
365         it.node_hazptr_.reset(node);
366         if (type == InsertType::MATCH) {
367           if (!match(node->getItem().second)) {
368             return false;
369           }
370         }
371         if (type == InsertType::DOES_NOT_EXIST) {
372           return false;
373         } else {
374           if (!cur) {
375             cur = (Node*)Allocator().allocate(sizeof(Node));
376             new (cur) Node(k, std::forward<Args>(args)...);
377           }
378           auto next = node->next_.load(std::memory_order_relaxed);
379           cur->next_.store(next, std::memory_order_relaxed);
380           if (next) {
381             next->acquire();
382           }
383           prev->store(cur, std::memory_order_release);
384           g.unlock();
385           // Release not under lock.
386           node->release();
387           return true;
388         }
389       }
390
391       prev = &node->next_;
392       node = node->next_.load(std::memory_order_relaxed);
393     }
394     if (type != InsertType::DOES_NOT_EXIST && type != InsertType::ANY) {
395       it.node_hazptr_.reset();
396       it.buckets_hazptr_.reset();
397       return false;
398     }
399     // Node not found, check for rehash on ANY
400     if (size_ >= load_factor_nodes_ && type == InsertType::ANY) {
401       if (max_size_ && size_ << 1 > max_size_) {
402         // Would exceed max size.
403         throw std::bad_alloc();
404       }
405       rehash(buckets->bucket_count_ << 1);
406
407       // Reload correct bucket.
408       buckets = buckets_.load(std::memory_order_relaxed);
409       it.buckets_hazptr_.reset(buckets);
410       idx = getIdx(buckets, h);
411       head = &buckets->buckets_[idx];
412       headnode = head->load(std::memory_order_relaxed);
413     }
414
415     // We found a slot to put the node.
416     size_++;
417     if (!cur) {
418       // InsertType::ANY
419       // OR DOES_NOT_EXIST, but only in the try_emplace case
420       DCHECK(type == InsertType::ANY || type == InsertType::DOES_NOT_EXIST);
421       cur = (Node*)Allocator().allocate(sizeof(Node));
422       new (cur) Node(k, std::forward<Args>(args)...);
423     }
424     cur->next_.store(headnode, std::memory_order_relaxed);
425     head->store(cur, std::memory_order_release);
426     it.setNode(cur, buckets, idx);
427     return true;
428   }
429
430   // Must hold lock.
431   void rehash(size_t bucket_count) {
432     auto buckets = buckets_.load(std::memory_order_relaxed);
433     auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
434     new (newbuckets) Buckets(bucket_count);
435
436     load_factor_nodes_ = bucket_count * load_factor_;
437
438     for (size_t i = 0; i < buckets->bucket_count_; i++) {
439       auto bucket = &buckets->buckets_[i];
440       auto node = bucket->load(std::memory_order_relaxed);
441       if (!node) {
442         continue;
443       }
444       auto h = HashFn()(node->getItem().first);
445       auto idx = getIdx(newbuckets, h);
446       // Reuse as long a chain as possible from the end.  Since the
447       // nodes don't have previous pointers, the longest last chain
448       // will be the same for both the previous hashmap and the new one,
449       // assuming all the nodes hash to the same bucket.
450       auto lastrun = node;
451       auto lastidx = idx;
452       auto count = 0;
453       auto last = node->next_.load(std::memory_order_relaxed);
454       for (; last != nullptr;
455            last = last->next_.load(std::memory_order_relaxed)) {
456         auto k = getIdx(newbuckets, HashFn()(last->getItem().first));
457         if (k != lastidx) {
458           lastidx = k;
459           lastrun = last;
460           count = 0;
461         }
462         count++;
463       }
464       // Set longest last run in new bucket, incrementing the refcount.
465       lastrun->acquire();
466       newbuckets->buckets_[lastidx].store(lastrun, std::memory_order_relaxed);
467       // Clone remaining nodes
468       for (; node != lastrun;
469            node = node->next_.load(std::memory_order_relaxed)) {
470         auto newnode = (Node*)Allocator().allocate(sizeof(Node));
471         new (newnode) Node(node);
472         auto k = getIdx(newbuckets, HashFn()(node->getItem().first));
473         auto prevhead = &newbuckets->buckets_[k];
474         newnode->next_.store(prevhead->load(std::memory_order_relaxed));
475         prevhead->store(newnode, std::memory_order_relaxed);
476       }
477     }
478
479     auto oldbuckets = buckets_.load(std::memory_order_relaxed);
480     buckets_.store(newbuckets, std::memory_order_release);
481     oldbuckets->retire(
482         folly::hazptr::default_hazptr_domain(),
483         concurrenthashmap::HazptrDeleter<Allocator>());
484   }
485
486   bool find(Iterator& res, const KeyType& k) {
487     folly::hazptr::hazptr_holder haznext;
488     auto h = HashFn()(k);
489     auto buckets = res.buckets_hazptr_.get_protected(buckets_);
490     auto idx = getIdx(buckets, h);
491     auto prev = &buckets->buckets_[idx];
492     auto node = res.node_hazptr_.get_protected(*prev);
493     while (node) {
494       if (KeyEqual()(k, node->getItem().first)) {
495         res.setNode(node, buckets, idx);
496         return true;
497       }
498       node = haznext.get_protected(node->next_);
499       haznext.swap(res.node_hazptr_);
500     }
501     return false;
502   }
503
504   // Listed separately because we need a prev pointer.
505   size_type erase(const key_type& key) {
506     return erase_internal(key, nullptr);
507   }
508
509   size_type erase_internal(const key_type& key, Iterator* iter) {
510     Node* node{nullptr};
511     auto h = HashFn()(key);
512     {
513       std::lock_guard<Mutex> g(m_);
514
515       auto buckets = buckets_.load(std::memory_order_relaxed);
516       auto idx = getIdx(buckets, h);
517       auto head = &buckets->buckets_[idx];
518       node = head->load(std::memory_order_relaxed);
519       Node* prev = nullptr;
520       auto headnode = node;
521       while (node) {
522         if (KeyEqual()(key, node->getItem().first)) {
523           auto next = node->next_.load(std::memory_order_relaxed);
524           if (next) {
525             next->acquire();
526           }
527           if (prev) {
528             prev->next_.store(next, std::memory_order_release);
529           } else {
530             // Must be head of list.
531             head->store(next, std::memory_order_release);
532           }
533
534           if (iter) {
535             iter->buckets_hazptr_.reset(buckets);
536             iter->setNode(
537                 node->next_.load(std::memory_order_acquire), buckets, idx);
538           }
539           size_--;
540           break;
541         }
542         prev = node;
543         node = node->next_.load(std::memory_order_relaxed);
544       }
545     }
546     // Delete the node while not under the lock.
547     if (node) {
548       node->release();
549       return 1;
550     }
551     DCHECK(!iter);
552
553     return 0;
554   }
555
556   // Unfortunately because we are reusing nodes on rehash, we can't
557   // have prev pointers in the bucket chain.  We have to start the
558   // search from the bucket.
559   //
560   // This is a small departure from standard stl containers: erase may
561   // throw if hash or key_eq functions throw.
562   void erase(Iterator& res, Iterator& pos) {
563     auto cnt = erase_internal(pos->first, &res);
564     DCHECK(cnt == 1);
565   }
566
567   void clear() {
568     auto buckets = buckets_.load(std::memory_order_relaxed);
569     auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
570     new (newbuckets) Buckets(buckets->bucket_count_);
571     {
572       std::lock_guard<Mutex> g(m_);
573       buckets_.store(newbuckets, std::memory_order_release);
574       size_ = 0;
575     }
576     buckets->retire(
577         folly::hazptr::default_hazptr_domain(),
578         concurrenthashmap::HazptrDeleter<Allocator>());
579   }
580
581   void max_load_factor(float factor) {
582     std::lock_guard<Mutex> g(m_);
583     load_factor_ = factor;
584     auto buckets = buckets_.load(std::memory_order_relaxed);
585     load_factor_nodes_ = buckets->bucket_count_ * load_factor_;
586   }
587
588   Iterator cbegin() {
589     Iterator res;
590     auto buckets = res.buckets_hazptr_.get_protected(buckets_);
591     res.setNode(nullptr, buckets, 0);
592     res.next();
593     return res;
594   }
595
596   Iterator cend() {
597     return Iterator(nullptr);
598   }
599
600   // Could be optimized to avoid an extra pointer dereference by
601   // allocating buckets_ at the same time.
602   class Buckets : public folly::hazptr::hazptr_obj_base<
603                       Buckets,
604                       concurrenthashmap::HazptrDeleter<Allocator>> {
605    public:
606     explicit Buckets(size_t count) : bucket_count_(count) {
607       buckets_ =
608           (Atom<Node*>*)Allocator().allocate(sizeof(Atom<Node*>) * count);
609       new (buckets_) Atom<Node*>[ count ];
610       for (size_t i = 0; i < count; i++) {
611         buckets_[i].store(nullptr, std::memory_order_relaxed);
612       }
613     }
614     ~Buckets() {
615       for (size_t i = 0; i < bucket_count_; i++) {
616         auto elem = buckets_[i].load(std::memory_order_relaxed);
617         if (elem) {
618           elem->release();
619         }
620       }
621       Allocator().deallocate(
622           (uint8_t*)buckets_, sizeof(Atom<Node*>) * bucket_count_);
623     }
624
625     size_t bucket_count_;
626     Atom<Node*>* buckets_{nullptr};
627   };
628
629  public:
630   class Iterator {
631    public:
632     FOLLY_ALWAYS_INLINE Iterator() {}
633     FOLLY_ALWAYS_INLINE explicit Iterator(std::nullptr_t)
634         : buckets_hazptr_(nullptr), node_hazptr_(nullptr) {}
635     FOLLY_ALWAYS_INLINE ~Iterator() {}
636
637     void setNode(Node* node, Buckets* buckets, uint64_t idx) {
638       node_ = node;
639       buckets_ = buckets;
640       idx_ = idx;
641     }
642
643     const value_type& operator*() const {
644       DCHECK(node_);
645       return node_->getItem();
646     }
647
648     const value_type* operator->() const {
649       DCHECK(node_);
650       return &(node_->getItem());
651     }
652
653     const Iterator& operator++() {
654       DCHECK(node_);
655       node_ = node_hazptr_.get_protected(node_->next_);
656       if (!node_) {
657         ++idx_;
658         next();
659       }
660       return *this;
661     }
662
663     void next() {
664       while (!node_) {
665         if (idx_ >= buckets_->bucket_count_) {
666           break;
667         }
668         DCHECK(buckets_);
669         DCHECK(buckets_->buckets_);
670         node_ = node_hazptr_.get_protected(buckets_->buckets_[idx_]);
671         if (node_) {
672           break;
673         }
674         ++idx_;
675       }
676     }
677
678     Iterator operator++(int) {
679       auto prev = *this;
680       ++*this;
681       return prev;
682     }
683
684     bool operator==(const Iterator& o) const {
685       return node_ == o.node_;
686     }
687
688     bool operator!=(const Iterator& o) const {
689       return !(*this == o);
690     }
691
692     Iterator& operator=(const Iterator& o) {
693       node_ = o.node_;
694       node_hazptr_.reset(node_);
695       idx_ = o.idx_;
696       buckets_ = o.buckets_;
697       buckets_hazptr_.reset(buckets_);
698       return *this;
699     }
700
701     /* implicit */ Iterator(const Iterator& o) {
702       node_ = o.node_;
703       node_hazptr_.reset(node_);
704       idx_ = o.idx_;
705       buckets_ = o.buckets_;
706       buckets_hazptr_.reset(buckets_);
707     }
708
709     /* implicit */ Iterator(Iterator&& o) noexcept
710         : buckets_hazptr_(std::move(o.buckets_hazptr_)),
711           node_hazptr_(std::move(o.node_hazptr_)) {
712       node_ = o.node_;
713       buckets_ = o.buckets_;
714       idx_ = o.idx_;
715     }
716
717     // These are accessed directly from the functions above
718     folly::hazptr::hazptr_holder buckets_hazptr_;
719     folly::hazptr::hazptr_holder node_hazptr_;
720
721    private:
722     Node* node_{nullptr};
723     Buckets* buckets_{nullptr};
724     uint64_t idx_;
725   };
726
727  private:
728   // Shards have already used low ShardBits of the hash.
729   // Shift it over to use fresh bits.
730   uint64_t getIdx(Buckets* buckets, size_t hash) {
731     return (hash >> ShardBits) & (buckets->bucket_count_ - 1);
732   }
733
734   float load_factor_;
735   size_t load_factor_nodes_;
736   size_t size_{0};
737   size_t max_size_{0};
738   Atom<Buckets*> buckets_{nullptr};
739   Mutex m_;
740 };
741 }
742 } // folly::detail namespace