ConcurrentHashMap
authorDave Watson <davejwatson@fb.com>
Wed, 26 Jul 2017 16:41:45 +0000 (09:41 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 26 Jul 2017 16:51:34 +0000 (09:51 -0700)
Summary:
A ConcurrentHashMap with wait-free readers, as in Java's ConcurrentHashMap.

It's a pretty generic closed-addressing chaining hashtable, except find() uses two hazard pointers
to do hand-over-hand traversal of the list, so it never takes a lock.

On rehash, only the part of the chain that remains the same (i.e. is still hashed to the same bucket)
is reused, otherwise we have to allocate new nodes.

Reallocating nodes means we either have to copy the value_type, or add in an extra indirection
to access it.  Both are supported.

There's still a couple opportunities to squeeze some more perf out with optimistic loading
of nodes / cachelines, but I didn't go that far yet, it sill looks pretty good.

Reviewed By: davidtgoldblatt

Differential Revision: D5349966

fbshipit-source-id: 022e8adacd0ddd32b2a4563caa99c0c4878851d8

folly/concurrency/ConcurrentHashMap.h [new file with mode: 0644]
folly/concurrency/detail/ConcurrentHashMap-detail.h [new file with mode: 0644]
folly/concurrency/test/ConcurrentHashMapTest.cpp [new file with mode: 0644]
folly/experimental/hazptr/hazptr-impl.h

diff --git a/folly/concurrency/ConcurrentHashMap.h b/folly/concurrency/ConcurrentHashMap.h
new file mode 100644 (file)
index 0000000..0bc4d16
--- /dev/null
@@ -0,0 +1,488 @@
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Optional.h>
+#include <folly/concurrency/detail/ConcurrentHashMap-detail.h>
+#include <folly/experimental/hazptr/hazptr.h>
+#include <atomic>
+#include <mutex>
+
+namespace folly {
+
+/**
+ * Based on Java's ConcurrentHashMap
+ *
+ * Readers are always wait-free.
+ * Writers are sharded, but take a lock.
+ *
+ * The interface is as close to std::unordered_map as possible, but there
+ * are a handful of changes:
+ *
+ * * Iterators hold hazard pointers to the returned elements.  Elements can only
+ *   be accessed while Iterators are still valid!
+ *
+ * * Therefore operator[] and at() return copies, since they do not
+ *   return an iterator.  The returned value is const, to remind you
+ *   that changes do not affect the value in the map.
+ *
+ * * erase() calls the hash function, and may fail if the hash
+ *   function throws an exception.
+ *
+ * * clear() initializes new segments, and is not noexcept.
+ *
+ * * The interface adds assign_if_equal, since find() doesn't take a lock.
+ *
+ * * Only const version of find() is supported, and const iterators.
+ *   Mutation must use functions provided, like assign().
+ *
+ * * iteration iterates over all the buckets in the table, unlike
+ *   std::unordered_map which iterates over a linked list of elements.
+ *   If the table is sparse, this may be more expensive.
+ *
+ * * rehash policy is a power of two, using supplied factor.
+ *
+ * * Allocator must be stateless.
+ *
+ * * ValueTypes without copy constructors will work, but pessimize the
+ *   implementation.
+ *
+ * Comparisons:
+ *      Single-threaded performance is extremely similar to std::unordered_map.
+ *
+ *      Multithreaded performance beats anything except the lock-free
+ *           atomic maps (AtomicUnorderedMap, AtomicHashMap), BUT only
+ *           if you can perfectly size the atomic maps, and you don't
+ *           need erase().  If you don't know the size in advance or
+ *           your workload frequently calls erase(), this is the
+ *           better choice.
+ */
+
+template <
+    typename KeyType,
+    typename ValueType,
+    typename HashFn = std::hash<KeyType>,
+    typename KeyEqual = std::equal_to<KeyType>,
+    typename Allocator = std::allocator<uint8_t>,
+    uint8_t ShardBits = 8,
+    template <typename> class Atom = std::atomic,
+    class Mutex = std::mutex>
+class ConcurrentHashMap {
+  using SegmentT = detail::ConcurrentHashMapSegment<
+      KeyType,
+      ValueType,
+      ShardBits,
+      HashFn,
+      KeyEqual,
+      Allocator,
+      Atom,
+      Mutex>;
+  static constexpr uint64_t NumShards = (1 << ShardBits);
+  // Slightly higher than 1.0, in case hashing to shards isn't
+  // perfectly balanced, reserve(size) will still work without
+  // rehashing.
+  float load_factor_ = 1.05;
+
+ public:
+  class ConstIterator;
+
+  typedef KeyType key_type;
+  typedef ValueType mapped_type;
+  typedef std::pair<const KeyType, ValueType> value_type;
+  typedef std::size_t size_type;
+  typedef HashFn hasher;
+  typedef KeyEqual key_equal;
+  typedef ConstIterator const_iterator;
+
+  /*
+   * Construct a ConcurrentHashMap with 1 << ShardBits shards, size
+   * and max_size given.  Both size and max_size will be rounded up to
+   * the next power of two, if they are not already a power of two, so
+   * that we can index in to Shards efficiently.
+   *
+   * Insertion functions will throw bad_alloc if max_size is exceeded.
+   */
+  explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) {
+    size_ = folly::nextPowTwo(size);
+    if (max_size != 0) {
+      max_size_ = folly::nextPowTwo(max_size);
+    }
+    CHECK(max_size_ == 0 || max_size_ >= size_);
+    for (uint64_t i = 0; i < NumShards; i++) {
+      segments_[i].store(nullptr, std::memory_order_relaxed);
+    }
+  }
+
+  ConcurrentHashMap(ConcurrentHashMap&& o) noexcept {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      segments_[i].store(
+          o.segments_[i].load(std::memory_order_relaxed),
+          std::memory_order_relaxed);
+      o.segments_[i].store(nullptr, std::memory_order_relaxed);
+    }
+  }
+
+  ConcurrentHashMap& operator=(ConcurrentHashMap&& o) {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_relaxed);
+      if (seg) {
+        seg->~SegmentT();
+        Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
+      }
+      segments_[i].store(
+          o.segments_[i].load(std::memory_order_relaxed),
+          std::memory_order_relaxed);
+      o.segments_[i].store(nullptr, std::memory_order_relaxed);
+    }
+    return *this;
+  }
+
+  ~ConcurrentHashMap() {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_relaxed);
+      if (seg) {
+        seg->~SegmentT();
+        Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
+      }
+    }
+  }
+
+  bool empty() const noexcept {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_acquire);
+      if (seg) {
+        if (!seg->empty()) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  ConstIterator find(const KeyType& k) const {
+    auto segment = pickSegment(k);
+    ConstIterator res(this, segment);
+    auto seg = segments_[segment].load(std::memory_order_acquire);
+    if (!seg || !seg->find(res.it_, k)) {
+      res.segment_ = NumShards;
+    }
+    return res;
+  }
+
+  ConstIterator cend() const noexcept {
+    return ConstIterator(NumShards);
+  }
+
+  ConstIterator cbegin() const noexcept {
+    return ConstIterator(this);
+  }
+
+  std::pair<ConstIterator, bool> insert(
+      std::pair<key_type, mapped_type>&& foo) {
+    auto segment = pickSegment(foo.first);
+    std::pair<ConstIterator, bool> res(
+        std::piecewise_construct,
+        std::forward_as_tuple(this, segment),
+        std::forward_as_tuple(false));
+    res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
+    return res;
+  }
+
+  std::pair<ConstIterator, bool> insert(const KeyType& k, const ValueType& v) {
+    auto segment = pickSegment(k);
+    std::pair<ConstIterator, bool> res(
+        std::piecewise_construct,
+        std::forward_as_tuple(this, segment),
+        std::forward_as_tuple(false));
+    res.second = ensureSegment(segment)->insert(res.first.it_, k, v);
+    return res;
+  }
+
+  template <typename... Args>
+  std::pair<ConstIterator, bool> try_emplace(const KeyType& k, Args&&... args) {
+    auto segment = pickSegment(k);
+    std::pair<ConstIterator, bool> res(
+        std::piecewise_construct,
+        std::forward_as_tuple(this, segment),
+        std::forward_as_tuple(false));
+    res.second = ensureSegment(segment)->try_emplace(
+        res.first.it_, k, std::forward<Args>(args)...);
+    return res;
+  }
+
+  template <typename... Args>
+  std::pair<ConstIterator, bool> emplace(Args&&... args) {
+    using Node = typename SegmentT::Node;
+    auto node = (Node*)Allocator().allocate(sizeof(Node));
+    new (node) Node(std::forward<Args>(args)...);
+    auto segment = pickSegment(node->getItem().first);
+    std::pair<ConstIterator, bool> res(
+        std::piecewise_construct,
+        std::forward_as_tuple(this, segment),
+        std::forward_as_tuple(false));
+    res.second = ensureSegment(segment)->emplace(
+        res.first.it_, node->getItem().first, node);
+    if (!res.second) {
+      node->~Node();
+      Allocator().deallocate((uint8_t*)node, sizeof(Node));
+    }
+    return res;
+  }
+
+  std::pair<ConstIterator, bool> insert_or_assign(
+      const KeyType& k,
+      const ValueType& v) {
+    auto segment = pickSegment(k);
+    std::pair<ConstIterator, bool> res(
+        std::piecewise_construct,
+        std::forward_as_tuple(this, segment),
+        std::forward_as_tuple(false));
+    res.second = ensureSegment(segment)->insert_or_assign(res.first.it_, k, v);
+    return res;
+  }
+
+  folly::Optional<ConstIterator> assign(const KeyType& k, const ValueType& v) {
+    auto segment = pickSegment(k);
+    ConstIterator res(this, segment);
+    auto seg = segments_[segment].load(std::memory_order_acquire);
+    if (!seg) {
+      return folly::Optional<ConstIterator>();
+    } else {
+      auto r = seg->assign(res.it_, k, v);
+      if (!r) {
+        return folly::Optional<ConstIterator>();
+      }
+    }
+    return res;
+  }
+
+  // Assign to desired if and only if key k is equal to expected
+  folly::Optional<ConstIterator> assign_if_equal(
+      const KeyType& k,
+      const ValueType& expected,
+      const ValueType& desired) {
+    auto segment = pickSegment(k);
+    ConstIterator res(this, segment);
+    auto seg = segments_[segment].load(std::memory_order_acquire);
+    if (!seg) {
+      return folly::Optional<ConstIterator>();
+    } else {
+      auto r = seg->assign_if_equal(res.it_, k, expected, desired);
+      if (!r) {
+        return folly::Optional<ConstIterator>();
+      }
+    }
+    return res;
+  }
+
+  // Copying wrappers around insert and find.
+  // Only available for copyable types.
+  const ValueType operator[](const KeyType& key) {
+    auto item = insert(key, ValueType());
+    return item.first->second;
+  }
+
+  const ValueType at(const KeyType& key) const {
+    auto item = find(key);
+    if (item == cend()) {
+      throw std::out_of_range("at(): value out of range");
+    }
+    return item->second;
+  }
+
+  // TODO update assign interface, operator[], at
+
+  size_type erase(const key_type& k) {
+    auto segment = pickSegment(k);
+    auto seg = segments_[segment].load(std::memory_order_acquire);
+    if (!seg) {
+      return 0;
+    } else {
+      return seg->erase(k);
+    }
+  }
+
+  // Calls the hash function, and therefore may throw.
+  ConstIterator erase(ConstIterator& pos) {
+    auto segment = pickSegment(pos->first);
+    ConstIterator res(this, segment);
+    res.next();
+    ensureSegment(segment)->erase(res.it_, pos.it_);
+    res.next(); // May point to segment end, and need to advance.
+    return res;
+  }
+
+  // NOT noexcept, initializes new shard segments vs.
+  void clear() {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_acquire);
+      if (seg) {
+        seg->clear();
+      }
+    }
+  }
+
+  void reserve(size_t count) {
+    count = count >> ShardBits;
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_acquire);
+      if (seg) {
+        seg->rehash(count);
+      }
+    }
+  }
+
+  // This is a rolling size, and is not exact at any moment in time.
+  size_t size() const noexcept {
+    size_t res = 0;
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_acquire);
+      if (seg) {
+        res += seg->size();
+      }
+    }
+    return res;
+  }
+
+  float max_load_factor() const {
+    return load_factor_;
+  }
+
+  void max_load_factor(float factor) {
+    for (uint64_t i = 0; i < NumShards; i++) {
+      auto seg = segments_[i].load(std::memory_order_acquire);
+      if (seg) {
+        seg->max_load_factor(factor);
+      }
+    }
+  }
+
+  class ConstIterator {
+   public:
+    friend class ConcurrentHashMap;
+
+    const value_type& operator*() const {
+      return *it_;
+    }
+
+    const value_type* operator->() const {
+      return &*it_;
+    }
+
+    ConstIterator& operator++() {
+      it_++;
+      next();
+      return *this;
+    }
+
+    ConstIterator operator++(int) {
+      auto prev = *this;
+      ++*this;
+      return prev;
+    }
+
+    bool operator==(const ConstIterator& o) const {
+      return it_ == o.it_ && segment_ == o.segment_;
+    }
+
+    bool operator!=(const ConstIterator& o) const {
+      return !(*this == o);
+    }
+
+    ConstIterator& operator=(const ConstIterator& o) {
+      it_ = o.it_;
+      segment_ = o.segment_;
+      return *this;
+    }
+
+    ConstIterator(const ConstIterator& o) {
+      it_ = o.it_;
+      segment_ = o.segment_;
+    }
+
+    ConstIterator(const ConcurrentHashMap* parent, uint64_t segment)
+        : segment_(segment), parent_(parent) {}
+
+   private:
+    // cbegin iterator
+    explicit ConstIterator(const ConcurrentHashMap* parent)
+        : it_(parent->ensureSegment(0)->cbegin()),
+          segment_(0),
+          parent_(parent) {
+      // Always iterate to the first element, could be in any shard.
+      next();
+    }
+
+    // cend iterator
+    explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {}
+
+    void next() {
+      while (it_ == parent_->ensureSegment(segment_)->cend() &&
+             segment_ < parent_->NumShards) {
+        segment_++;
+        auto seg = parent_->segments_[segment_].load(std::memory_order_acquire);
+        if (segment_ < parent_->NumShards) {
+          if (!seg) {
+            continue;
+          }
+          it_ = seg->cbegin();
+        }
+      }
+    }
+
+    typename SegmentT::Iterator it_;
+    uint64_t segment_;
+    const ConcurrentHashMap* parent_;
+  };
+
+ private:
+  uint64_t pickSegment(const KeyType& k) const {
+    auto h = HashFn()(k);
+    // Use the lowest bits for our shard bits.
+    //
+    // This works well even if the hash function is biased towards the
+    // low bits: The sharding will happen in the segments_ instead of
+    // in the segment buckets, so we'll still get write sharding as
+    // well.
+    //
+    // Low-bit bias happens often for std::hash using small numbers,
+    // since the integer hash function is the identity function.
+    return h & (NumShards - 1);
+  }
+
+  SegmentT* ensureSegment(uint64_t i) const {
+    auto seg = segments_[i].load(std::memory_order_acquire);
+    if (!seg) {
+      auto newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT));
+      new (newseg)
+          SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits);
+      if (!segments_[i].compare_exchange_strong(seg, newseg)) {
+        // seg is updated with new value, delete ours.
+        newseg->~SegmentT();
+        Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT));
+      } else {
+        seg = newseg;
+      }
+    }
+    return seg;
+  }
+
+  mutable Atom<SegmentT*> segments_[NumShards];
+  size_t size_{0};
+  size_t max_size_{0};
+};
+
+} // namespace
diff --git a/folly/concurrency/detail/ConcurrentHashMap-detail.h b/folly/concurrency/detail/ConcurrentHashMap-detail.h
new file mode 100644 (file)
index 0000000..ec1c8df
--- /dev/null
@@ -0,0 +1,742 @@
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/experimental/hazptr/hazptr.h>
+#include <atomic>
+#include <mutex>
+
+namespace folly {
+
+namespace detail {
+
+namespace concurrenthashmap {
+
+// hazptr retire() that can use an allocator.
+template <typename Allocator>
+class HazptrDeleter {
+ public:
+  template <typename Node>
+  void operator()(Node* node) {
+    node->~Node();
+    Allocator().deallocate((uint8_t*)node, sizeof(Node));
+  }
+};
+
+template <
+    typename KeyType,
+    typename ValueType,
+    typename Allocator,
+    typename Enabled = void>
+class ValueHolder {
+ public:
+  typedef std::pair<const KeyType, ValueType> value_type;
+
+  explicit ValueHolder(const ValueHolder& other) : item_(other.item_) {}
+
+  template <typename... Args>
+  ValueHolder(const KeyType& k, Args&&... args)
+      : item_(
+            std::piecewise_construct,
+            std::forward_as_tuple(k),
+            std::forward_as_tuple(std::forward<Args>(args)...)) {}
+  value_type& getItem() {
+    return item_;
+  }
+
+ private:
+  value_type item_;
+};
+
+// If the ValueType is not copy constructible, we can instead add
+// an extra indirection.  Adds more allocations / deallocations and
+// pulls in an extra cacheline.
+template <typename KeyType, typename ValueType, typename Allocator>
+class ValueHolder<
+    KeyType,
+    ValueType,
+    Allocator,
+    std::enable_if_t<!std::is_nothrow_copy_constructible<ValueType>::value>> {
+ public:
+  typedef std::pair<const KeyType, ValueType> value_type;
+
+  explicit ValueHolder(const ValueHolder& other) {
+    other.owned_ = false;
+    item_ = other.item_;
+  }
+
+  template <typename... Args>
+  ValueHolder(const KeyType& k, Args&&... args) {
+    item_ = (value_type*)Allocator().allocate(sizeof(value_type));
+    new (item_) value_type(
+        std::piecewise_construct,
+        std::forward_as_tuple(k),
+        std::forward_as_tuple(std::forward<Args>(args)...));
+  }
+
+  ~ValueHolder() {
+    if (owned_) {
+      item_->~value_type();
+      Allocator().deallocate((uint8_t*)item_, sizeof(value_type));
+    }
+  }
+
+  value_type& getItem() {
+    return *item_;
+  }
+
+ private:
+  value_type* item_;
+  mutable bool owned_{true};
+};
+
+template <
+    typename KeyType,
+    typename ValueType,
+    typename Allocator,
+    template <typename> class Atom = std::atomic>
+class NodeT : public folly::hazptr::hazptr_obj_base<
+                  NodeT<KeyType, ValueType, Allocator, Atom>,
+                  concurrenthashmap::HazptrDeleter<Allocator>> {
+ public:
+  typedef std::pair<const KeyType, ValueType> value_type;
+
+  explicit NodeT(NodeT* other) : item_(other->item_) {}
+
+  template <typename... Args>
+  NodeT(const KeyType& k, Args&&... args)
+      : item_(k, std::forward<Args>(args)...) {}
+
+  /* Nodes are refcounted: If a node is retired() while a writer is
+     traversing the chain, the rest of the chain must remain valid
+     until all readers are finished.  This includes the shared tail
+     portion of the chain, as well as both old/new hash buckets that
+     may point to the same portion, and erased nodes may increase the
+     refcount */
+  void acquire() {
+    DCHECK(refcount_.load() != 0);
+    refcount_.fetch_add(1);
+  }
+  void release() {
+    if (refcount_.fetch_sub(1) == 1 /* was previously 1 */) {
+      this->retire(
+          folly::hazptr::default_hazptr_domain(),
+          concurrenthashmap::HazptrDeleter<Allocator>());
+    }
+  }
+  ~NodeT() {
+    auto next = next_.load(std::memory_order_acquire);
+    if (next) {
+      next->release();
+    }
+  }
+
+  value_type& getItem() {
+    return item_.getItem();
+  }
+  Atom<NodeT*> next_{nullptr};
+
+ private:
+  ValueHolder<KeyType, ValueType, Allocator> item_;
+  Atom<uint8_t> refcount_{1};
+};
+
+} // namespace concurrenthashmap
+
+/* A Segment is a single shard of the ConcurrentHashMap.
+ * All writes take the lock, while readers are all wait-free.
+ * Readers always proceed in parallel with the single writer.
+ *
+ *
+ * Possible additional optimizations:
+ *
+ * * insert / erase could be lock / wait free.  Would need to be
+ *   careful that assign and rehash don't conflict (possibly with
+ *   reader/writer lock, or microlock per node or per bucket, etc).
+ *   Java 8 goes halfway, and and does lock per bucket, except for the
+ *   first item, that is inserted with a CAS (which is somewhat
+ *   specific to java having a lock per object)
+ *
+ * * I tried using trylock() and find() to warm the cache for insert()
+ *   and erase() similar to Java 7, but didn't have much luck.
+ *
+ * * We could order elements using split ordering, for faster rehash,
+ *   and no need to ever copy nodes.  Note that a full split ordering
+ *   including dummy nodes increases the memory usage by 2x, but we
+ *   could split the difference and still require a lock to set bucket
+ *   pointers.
+ *
+ * * hazptr acquire/release could be optimized more, in
+ *   single-threaded case, hazptr overhead is ~30% for a hot find()
+ *   loop.
+ */
+template <
+    typename KeyType,
+    typename ValueType,
+    uint8_t ShardBits = 0,
+    typename HashFn = std::hash<KeyType>,
+    typename KeyEqual = std::equal_to<KeyType>,
+    typename Allocator = std::allocator<uint8_t>,
+    template <typename> class Atom = std::atomic,
+    class Mutex = std::mutex>
+class FOLLY_ALIGNED(64) ConcurrentHashMapSegment {
+  enum class InsertType {
+    DOES_NOT_EXIST, // insert/emplace operations.  If key exists, return false.
+    MUST_EXIST, // assign operations.  If key does not exist, return false.
+    ANY, // insert_or_assign.
+    MATCH, // assign_if_equal (not in std).  For concurrent maps, a
+           // way to atomically change a value if equal to some other
+           // value.
+  };
+
+ public:
+  typedef KeyType key_type;
+  typedef ValueType mapped_type;
+  typedef std::pair<const KeyType, ValueType> value_type;
+  typedef std::size_t size_type;
+
+  using Node = concurrenthashmap::NodeT<KeyType, ValueType, Allocator, Atom>;
+  class Iterator;
+
+  ConcurrentHashMapSegment(
+      size_t initial_buckets,
+      float load_factor,
+      size_t max_size)
+      : load_factor_(load_factor) {
+    auto buckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+    initial_buckets = folly::nextPowTwo(initial_buckets);
+    if (max_size != 0) {
+      max_size_ = folly::nextPowTwo(max_size);
+    }
+    if (max_size_ > max_size) {
+      max_size_ >> 1;
+    }
+
+    CHECK(max_size_ == 0 || (folly::popcount(max_size_ - 1) + ShardBits <= 32));
+    new (buckets) Buckets(initial_buckets);
+    buckets_.store(buckets, std::memory_order_release);
+    load_factor_nodes_ = initial_buckets * load_factor_;
+  }
+
+  ~ConcurrentHashMapSegment() {
+    auto buckets = buckets_.load(std::memory_order_relaxed);
+    // We can delete and not retire() here, since users must have
+    // their own synchronization around destruction.
+    buckets->~Buckets();
+    Allocator().deallocate((uint8_t*)buckets, sizeof(Buckets));
+  }
+
+  size_t size() {
+    return size_;
+  }
+
+  bool empty() {
+    return size() == 0;
+  }
+
+  bool insert(Iterator& it, std::pair<key_type, mapped_type>&& foo) {
+    return insert(it, foo.first, foo.second);
+  }
+
+  bool insert(Iterator& it, const KeyType& k, const ValueType& v) {
+    auto node = (Node*)Allocator().allocate(sizeof(Node));
+    new (node) Node(k, v);
+    auto res = insert_internal(
+        it,
+        k,
+        InsertType::DOES_NOT_EXIST,
+        [](const ValueType&) { return false; },
+        node,
+        v);
+    if (!res) {
+      node->~Node();
+      Allocator().deallocate((uint8_t*)node, sizeof(Node));
+    }
+    return res;
+  }
+
+  template <typename... Args>
+  bool try_emplace(Iterator& it, const KeyType& k, Args&&... args) {
+    return insert_internal(
+        it,
+        k,
+        InsertType::DOES_NOT_EXIST,
+        [](const ValueType&) { return false; },
+        nullptr,
+        std::forward<Args>(args)...);
+  }
+
+  template <typename... Args>
+  bool emplace(Iterator& it, const KeyType& k, Node* node) {
+    return insert_internal(
+        it,
+        k,
+        InsertType::DOES_NOT_EXIST,
+        [](const ValueType&) { return false; },
+        node);
+  }
+
+  bool insert_or_assign(Iterator& it, const KeyType& k, const ValueType& v) {
+    return insert_internal(
+        it,
+        k,
+        InsertType::ANY,
+        [](const ValueType&) { return false; },
+        nullptr,
+        v);
+  }
+
+  bool assign(Iterator& it, const KeyType& k, const ValueType& v) {
+    auto node = (Node*)Allocator().allocate(sizeof(Node));
+    new (node) Node(k, v);
+    auto res = insert_internal(
+        it,
+        k,
+        InsertType::MUST_EXIST,
+        [](const ValueType&) { return false; },
+        node,
+        v);
+    if (!res) {
+      node->~Node();
+      Allocator().deallocate((uint8_t*)node, sizeof(Node));
+    }
+    return res;
+  }
+
+  bool assign_if_equal(
+      Iterator& it,
+      const KeyType& k,
+      const ValueType& expected,
+      const ValueType& desired) {
+    return insert_internal(
+        it,
+        k,
+        InsertType::MATCH,
+        [expected](const ValueType& v) { return v == expected; },
+        nullptr,
+        desired);
+  }
+
+  template <typename MatchFunc, typename... Args>
+  bool insert_internal(
+      Iterator& it,
+      const KeyType& k,
+      InsertType type,
+      MatchFunc match,
+      Node* cur,
+      Args&&... args) {
+    auto h = HashFn()(k);
+    std::unique_lock<Mutex> g(m_);
+
+    auto buckets = buckets_.load(std::memory_order_relaxed);
+    // Check for rehash needed for DOES_NOT_EXIST
+    if (size_ >= load_factor_nodes_ && type == InsertType::DOES_NOT_EXIST) {
+      if (max_size_ && size_ << 1 > max_size_) {
+        // Would exceed max size.
+        throw std::bad_alloc();
+      }
+      rehash(buckets->bucket_count_ << 1);
+      buckets = buckets_.load(std::memory_order_relaxed);
+    }
+
+    auto idx = getIdx(buckets, h);
+    auto head = &buckets->buckets_[idx];
+    auto node = head->load(std::memory_order_relaxed);
+    auto headnode = node;
+    auto prev = head;
+    it.buckets_hazptr_.reset(buckets);
+    while (node) {
+      // Is the key found?
+      if (KeyEqual()(k, node->getItem().first)) {
+        it.setNode(node, buckets, idx);
+        it.node_hazptr_.reset(node);
+        if (type == InsertType::MATCH) {
+          if (!match(node->getItem().second)) {
+            return false;
+          }
+        }
+        if (type == InsertType::DOES_NOT_EXIST) {
+          return false;
+        } else {
+          if (!cur) {
+            cur = (Node*)Allocator().allocate(sizeof(Node));
+            new (cur) Node(k, std::forward<Args>(args)...);
+          }
+          auto next = node->next_.load(std::memory_order_relaxed);
+          cur->next_.store(next, std::memory_order_relaxed);
+          if (next) {
+            next->acquire();
+          }
+          prev->store(cur, std::memory_order_release);
+          g.unlock();
+          // Release not under lock.
+          node->release();
+          return true;
+        }
+      }
+
+      prev = &node->next_;
+      node = node->next_.load(std::memory_order_relaxed);
+    }
+    if (type != InsertType::DOES_NOT_EXIST && type != InsertType::ANY) {
+      it.node_hazptr_.reset();
+      it.buckets_hazptr_.reset();
+      return false;
+    }
+    // Node not found, check for rehash on ANY
+    if (size_ >= load_factor_nodes_ && type == InsertType::ANY) {
+      if (max_size_ && size_ << 1 > max_size_) {
+        // Would exceed max size.
+        throw std::bad_alloc();
+      }
+      rehash(buckets->bucket_count_ << 1);
+
+      // Reload correct bucket.
+      buckets = buckets_.load(std::memory_order_relaxed);
+      it.buckets_hazptr_.reset(buckets);
+      idx = getIdx(buckets, h);
+      head = &buckets->buckets_[idx];
+      headnode = head->load(std::memory_order_relaxed);
+    }
+
+    // We found a slot to put the node.
+    size_++;
+    if (!cur) {
+      // InsertType::ANY
+      // OR DOES_NOT_EXIST, but only in the try_emplace case
+      DCHECK(type == InsertType::ANY || type == InsertType::DOES_NOT_EXIST);
+      cur = (Node*)Allocator().allocate(sizeof(Node));
+      new (cur) Node(k, std::forward<Args>(args)...);
+    }
+    cur->next_.store(headnode, std::memory_order_relaxed);
+    head->store(cur, std::memory_order_release);
+    it.setNode(cur, buckets, idx);
+    return true;
+  }
+
+  // Must hold lock.
+  void rehash(size_t bucket_count) {
+    auto buckets = buckets_.load(std::memory_order_relaxed);
+    auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+    new (newbuckets) Buckets(bucket_count);
+
+    load_factor_nodes_ = bucket_count * load_factor_;
+
+    for (size_t i = 0; i < buckets->bucket_count_; i++) {
+      auto bucket = &buckets->buckets_[i];
+      auto node = bucket->load(std::memory_order_relaxed);
+      if (!node) {
+        continue;
+      }
+      auto h = HashFn()(node->getItem().first);
+      auto idx = getIdx(newbuckets, h);
+      // Reuse as long a chain as possible from the end.  Since the
+      // nodes don't have previous pointers, the longest last chain
+      // will be the same for both the previous hashmap and the new one,
+      // assuming all the nodes hash to the same bucket.
+      auto lastrun = node;
+      auto lastidx = idx;
+      auto count = 0;
+      auto last = node->next_.load(std::memory_order_relaxed);
+      for (; last != nullptr;
+           last = last->next_.load(std::memory_order_relaxed)) {
+        auto k = getIdx(newbuckets, HashFn()(last->getItem().first));
+        if (k != lastidx) {
+          lastidx = k;
+          lastrun = last;
+          count = 0;
+        }
+        count++;
+      }
+      // Set longest last run in new bucket, incrementing the refcount.
+      lastrun->acquire();
+      newbuckets->buckets_[lastidx].store(lastrun, std::memory_order_relaxed);
+      // Clone remaining nodes
+      for (; node != lastrun;
+           node = node->next_.load(std::memory_order_relaxed)) {
+        auto newnode = (Node*)Allocator().allocate(sizeof(Node));
+        new (newnode) Node(node);
+        auto k = getIdx(newbuckets, HashFn()(node->getItem().first));
+        auto prevhead = &newbuckets->buckets_[k];
+        newnode->next_.store(prevhead->load(std::memory_order_relaxed));
+        prevhead->store(newnode, std::memory_order_relaxed);
+      }
+    }
+
+    auto oldbuckets = buckets_.load(std::memory_order_relaxed);
+    buckets_.store(newbuckets, std::memory_order_release);
+    oldbuckets->retire(
+        folly::hazptr::default_hazptr_domain(),
+        concurrenthashmap::HazptrDeleter<Allocator>());
+  }
+
+  bool find(Iterator& res, const KeyType& k) {
+    folly::hazptr::hazptr_holder haznext;
+    auto h = HashFn()(k);
+    auto buckets = res.buckets_hazptr_.get_protected(buckets_);
+    auto idx = getIdx(buckets, h);
+    auto prev = &buckets->buckets_[idx];
+    auto node = res.node_hazptr_.get_protected(*prev);
+    while (node) {
+      if (KeyEqual()(k, node->getItem().first)) {
+        res.setNode(node, buckets, idx);
+        return true;
+      }
+      node = haznext.get_protected(node->next_);
+      haznext.swap(res.node_hazptr_);
+    }
+    return false;
+  }
+
+  // Listed separately because we need a prev pointer.
+  size_type erase(const key_type& key) {
+    return erase_internal(key, nullptr);
+  }
+
+  size_type erase_internal(const key_type& key, Iterator* iter) {
+    Node* node{nullptr};
+    auto h = HashFn()(key);
+    {
+      std::lock_guard<Mutex> g(m_);
+
+      auto buckets = buckets_.load(std::memory_order_relaxed);
+      auto idx = getIdx(buckets, h);
+      auto head = &buckets->buckets_[idx];
+      node = head->load(std::memory_order_relaxed);
+      Node* prev = nullptr;
+      auto headnode = node;
+      while (node) {
+        if (KeyEqual()(key, node->getItem().first)) {
+          auto next = node->next_.load(std::memory_order_relaxed);
+          if (next) {
+            next->acquire();
+          }
+          if (prev) {
+            prev->next_.store(next, std::memory_order_release);
+          } else {
+            // Must be head of list.
+            head->store(next, std::memory_order_release);
+          }
+
+          if (iter) {
+            iter->buckets_hazptr_.reset(buckets);
+            iter->setNode(
+                node->next_.load(std::memory_order_acquire), buckets, idx);
+          }
+          size_--;
+          break;
+        }
+        prev = node;
+        node = node->next_.load(std::memory_order_relaxed);
+      }
+    }
+    // Delete the node while not under the lock.
+    if (node) {
+      node->release();
+      return 1;
+    }
+    DCHECK(!iter);
+
+    return 0;
+  }
+
+  // Unfortunately because we are reusing nodes on rehash, we can't
+  // have prev pointers in the bucket chain.  We have to start the
+  // search from the bucket.
+  //
+  // This is a small departure from standard stl containers: erase may
+  // throw if hash or key_eq functions throw.
+  void erase(Iterator& res, Iterator& pos) {
+    auto cnt = erase_internal(pos->first, &res);
+    DCHECK(cnt == 1);
+  }
+
+  void clear() {
+    auto buckets = buckets_.load(std::memory_order_relaxed);
+    auto newbuckets = (Buckets*)Allocator().allocate(sizeof(Buckets));
+    new (newbuckets) Buckets(buckets->bucket_count_);
+    {
+      std::lock_guard<Mutex> g(m_);
+      buckets_.store(newbuckets, std::memory_order_release);
+      size_ = 0;
+    }
+    buckets->retire(
+        folly::hazptr::default_hazptr_domain(),
+        concurrenthashmap::HazptrDeleter<Allocator>());
+  }
+
+  void max_load_factor(float factor) {
+    std::lock_guard<Mutex> g(m_);
+    load_factor_ = factor;
+    auto buckets = buckets_.load(std::memory_order_relaxed);
+    load_factor_nodes_ = buckets->bucket_count_ * load_factor_;
+  }
+
+  Iterator cbegin() {
+    Iterator res;
+    auto buckets = res.buckets_hazptr_.get_protected(buckets_);
+    res.setNode(nullptr, buckets, 0);
+    res.next();
+    return res;
+  }
+
+  Iterator cend() {
+    return Iterator(nullptr);
+  }
+
+  // Could be optimized to avoid an extra pointer dereference by
+  // allocating buckets_ at the same time.
+  class Buckets : public folly::hazptr::hazptr_obj_base<
+                      Buckets,
+                      concurrenthashmap::HazptrDeleter<Allocator>> {
+   public:
+    explicit Buckets(size_t count) : bucket_count_(count) {
+      buckets_ =
+          (Atom<Node*>*)Allocator().allocate(sizeof(Atom<Node*>) * count);
+      new (buckets_) Atom<Node*>[ count ];
+      for (size_t i = 0; i < count; i++) {
+        buckets_[i].store(nullptr, std::memory_order_relaxed);
+      }
+    }
+    ~Buckets() {
+      for (size_t i = 0; i < bucket_count_; i++) {
+        auto elem = buckets_[i].load(std::memory_order_relaxed);
+        if (elem) {
+          elem->release();
+        }
+      }
+      Allocator().deallocate(
+          (uint8_t*)buckets_, sizeof(Atom<Node*>) * bucket_count_);
+    }
+
+    size_t bucket_count_;
+    Atom<Node*>* buckets_{nullptr};
+  };
+
+ public:
+  class Iterator {
+   public:
+    FOLLY_ALWAYS_INLINE Iterator() {}
+    FOLLY_ALWAYS_INLINE explicit Iterator(std::nullptr_t)
+        : buckets_hazptr_(nullptr), node_hazptr_(nullptr) {}
+    FOLLY_ALWAYS_INLINE ~Iterator() {}
+
+    void setNode(Node* node, Buckets* buckets, uint64_t idx) {
+      node_ = node;
+      buckets_ = buckets;
+      idx_ = idx;
+    }
+
+    const value_type& operator*() const {
+      DCHECK(node_);
+      return node_->getItem();
+    }
+
+    const value_type* operator->() const {
+      DCHECK(node_);
+      return &(node_->getItem());
+    }
+
+    const Iterator& operator++() {
+      DCHECK(node_);
+      node_ = node_hazptr_.get_protected(node_->next_);
+      if (!node_) {
+        ++idx_;
+        next();
+      }
+      return *this;
+    }
+
+    void next() {
+      while (!node_) {
+        if (idx_ >= buckets_->bucket_count_) {
+          break;
+        }
+        DCHECK(buckets_);
+        DCHECK(buckets_->buckets_);
+        node_ = node_hazptr_.get_protected(buckets_->buckets_[idx_]);
+        if (node_) {
+          break;
+        }
+        ++idx_;
+      }
+    }
+
+    Iterator operator++(int) {
+      auto prev = *this;
+      ++*this;
+      return prev;
+    }
+
+    bool operator==(const Iterator& o) const {
+      return node_ == o.node_;
+    }
+
+    bool operator!=(const Iterator& o) const {
+      return !(*this == o);
+    }
+
+    Iterator& operator=(const Iterator& o) {
+      node_ = o.node_;
+      node_hazptr_.reset(node_);
+      idx_ = o.idx_;
+      buckets_ = o.buckets_;
+      buckets_hazptr_.reset(buckets_);
+      return *this;
+    }
+
+    /* implicit */ Iterator(const Iterator& o) {
+      node_ = o.node_;
+      node_hazptr_.reset(node_);
+      idx_ = o.idx_;
+      buckets_ = o.buckets_;
+      buckets_hazptr_.reset(buckets_);
+    }
+
+    /* implicit */ Iterator(Iterator&& o) noexcept
+        : buckets_hazptr_(std::move(o.buckets_hazptr_)),
+          node_hazptr_(std::move(o.node_hazptr_)) {
+      node_ = o.node_;
+      buckets_ = o.buckets_;
+      idx_ = o.idx_;
+    }
+
+    // These are accessed directly from the functions above
+    folly::hazptr::hazptr_holder buckets_hazptr_;
+    folly::hazptr::hazptr_holder node_hazptr_;
+
+   private:
+    Node* node_{nullptr};
+    Buckets* buckets_{nullptr};
+    uint64_t idx_;
+  };
+
+ private:
+  // Shards have already used low ShardBits of the hash.
+  // Shift it over to use fresh bits.
+  uint64_t getIdx(Buckets* buckets, size_t hash) {
+    return (hash >> ShardBits) & (buckets->bucket_count_ - 1);
+  }
+
+  float load_factor_;
+  size_t load_factor_nodes_;
+  size_t size_{0};
+  size_t max_size_{0};
+  Atom<Buckets*> buckets_{nullptr};
+  Mutex m_;
+};
+}
+} // folly::detail namespace
diff --git a/folly/concurrency/test/ConcurrentHashMapTest.cpp b/folly/concurrency/test/ConcurrentHashMapTest.cpp
new file mode 100644 (file)
index 0000000..4a1d23d
--- /dev/null
@@ -0,0 +1,483 @@
+/*
+ * Copyright 2017-present Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <atomic>
+#include <memory>
+#include <thread>
+
+#include <folly/Hash.h>
+#include <folly/concurrency/ConcurrentHashMap.h>
+#include <folly/portability/GTest.h>
+#include <folly/test/DeterministicSchedule.h>
+
+using namespace folly::test;
+using namespace folly;
+using namespace std;
+
+DEFINE_int64(seed, 0, "Seed for random number generators");
+
+TEST(ConcurrentHashMap, MapTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(3);
+  foomap.max_load_factor(1.05);
+  EXPECT_TRUE(foomap.empty());
+  EXPECT_EQ(foomap.find(1), foomap.cend());
+  auto r = foomap.insert(1, 0);
+  EXPECT_TRUE(r.second);
+  auto r2 = foomap.insert(1, 0);
+  EXPECT_EQ(r.first->second, 0);
+  EXPECT_EQ(r.first->first, 1);
+  EXPECT_EQ(r2.first->second, 0);
+  EXPECT_EQ(r2.first->first, 1);
+  EXPECT_EQ(r.first, r2.first);
+  EXPECT_TRUE(r.second);
+  EXPECT_FALSE(r2.second);
+  EXPECT_FALSE(foomap.empty());
+  EXPECT_TRUE(foomap.insert(std::make_pair(2, 0)).second);
+  EXPECT_TRUE(foomap.insert_or_assign(2, 0).second);
+  EXPECT_TRUE(foomap.assign_if_equal(2, 0, 3));
+  EXPECT_TRUE(foomap.insert(3, 0).second);
+  EXPECT_NE(foomap.find(1), foomap.cend());
+  EXPECT_NE(foomap.find(2), foomap.cend());
+  EXPECT_EQ(foomap.find(2)->second, 3);
+  EXPECT_EQ(foomap[2], 3);
+  EXPECT_EQ(foomap[20], 0);
+  EXPECT_EQ(foomap.at(20), 0);
+  EXPECT_FALSE(foomap.insert(1, 0).second);
+  auto l = foomap.find(1);
+  foomap.erase(l);
+  EXPECT_FALSE(foomap.erase(1));
+  EXPECT_EQ(foomap.find(1), foomap.cend());
+  auto res = foomap.find(2);
+  EXPECT_NE(res, foomap.cend());
+  EXPECT_EQ(3, res->second);
+  EXPECT_FALSE(foomap.empty());
+  foomap.clear();
+  EXPECT_TRUE(foomap.empty());
+}
+
+TEST(ConcurrentHashMap, MaxSizeTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2, 16);
+  bool insert_failed = false;
+  for (int i = 0; i < 32; i++) {
+    auto res = foomap.insert(0, 0);
+    if (!res.second) {
+      insert_failed = true;
+    }
+  }
+  EXPECT_TRUE(insert_failed);
+}
+
+TEST(ConcurrentHashMap, MoveTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2, 16);
+  auto other = std::move(foomap);
+  auto other2 = std::move(other);
+  other = std::move(other2);
+}
+
+struct foo {
+  static int moved;
+  static int copied;
+  foo(foo&& o) noexcept {
+    (void*)&o;
+    moved++;
+  }
+  foo& operator=(foo&& o) {
+    (void*)&o;
+    moved++;
+    return *this;
+  }
+  foo& operator=(const foo& o) {
+    (void*)&o;
+    copied++;
+    return *this;
+  }
+  foo(const foo& o) {
+    (void*)&o;
+    copied++;
+  }
+  foo() {}
+};
+int foo::moved{0};
+int foo::copied{0};
+
+TEST(ConcurrentHashMap, EmplaceTest) {
+  ConcurrentHashMap<uint64_t, foo> foomap(200);
+  foomap.insert(1, foo());
+  EXPECT_EQ(foo::moved, 0);
+  EXPECT_EQ(foo::copied, 1);
+  foo::copied = 0;
+  // The difference between emplace and try_emplace:
+  // If insertion fails, try_emplace does not move its argument
+  foomap.try_emplace(1, foo());
+  EXPECT_EQ(foo::moved, 0);
+  EXPECT_EQ(foo::copied, 0);
+  foomap.emplace(1, foo());
+  EXPECT_EQ(foo::moved, 1);
+  EXPECT_EQ(foo::copied, 0);
+}
+
+TEST(ConcurrentHashMap, MapResizeTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+  EXPECT_EQ(foomap.find(1), foomap.cend());
+  EXPECT_TRUE(foomap.insert(1, 0).second);
+  EXPECT_TRUE(foomap.insert(2, 0).second);
+  EXPECT_TRUE(foomap.insert(3, 0).second);
+  EXPECT_TRUE(foomap.insert(4, 0).second);
+  foomap.reserve(512);
+  EXPECT_NE(foomap.find(1), foomap.cend());
+  EXPECT_NE(foomap.find(2), foomap.cend());
+  EXPECT_FALSE(foomap.insert(1, 0).second);
+  EXPECT_TRUE(foomap.erase(1));
+  EXPECT_EQ(foomap.find(1), foomap.cend());
+  auto res = foomap.find(2);
+  EXPECT_NE(res, foomap.cend());
+  if (res != foomap.cend()) {
+    EXPECT_EQ(0, res->second);
+  }
+}
+
+// Ensure we can insert objects without copy constructors.
+TEST(ConcurrentHashMap, MapNoCopiesTest) {
+  struct Uncopyable {
+    Uncopyable(int i) {
+      (void*)&i;
+    }
+    Uncopyable(const Uncopyable& that) = delete;
+  };
+  ConcurrentHashMap<uint64_t, Uncopyable> foomap(2);
+  EXPECT_TRUE(foomap.try_emplace(1, 1).second);
+  EXPECT_TRUE(foomap.try_emplace(2, 2).second);
+  auto res = foomap.find(2);
+  EXPECT_NE(res, foomap.cend());
+
+  EXPECT_TRUE(foomap.try_emplace(3, 3).second);
+
+  auto res2 = foomap.find(2);
+  EXPECT_NE(res2, foomap.cend());
+  EXPECT_EQ(&(res->second), &(res2->second));
+}
+
+TEST(ConcurrentHashMap, MapUpdateTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+  EXPECT_TRUE(foomap.insert(1, 10).second);
+  EXPECT_TRUE(bool(foomap.assign(1, 11)));
+  auto res = foomap.find(1);
+  EXPECT_NE(res, foomap.cend());
+  EXPECT_EQ(11, res->second);
+}
+
+TEST(ConcurrentHashMap, MapIterateTest2) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+  auto begin = foomap.cbegin();
+  auto end = foomap.cend();
+  EXPECT_EQ(begin, end);
+}
+
+TEST(ConcurrentHashMap, MapIterateTest) {
+  ConcurrentHashMap<uint64_t, uint64_t> foomap(2);
+  EXPECT_EQ(foomap.cbegin(), foomap.cend());
+  EXPECT_TRUE(foomap.insert(1, 1).second);
+  EXPECT_TRUE(foomap.insert(2, 2).second);
+  auto iter = foomap.cbegin();
+  EXPECT_NE(iter, foomap.cend());
+  EXPECT_EQ(iter->first, 1);
+  EXPECT_EQ(iter->second, 1);
+  iter++;
+  EXPECT_NE(iter, foomap.cend());
+  EXPECT_EQ(iter->first, 2);
+  EXPECT_EQ(iter->second, 2);
+  iter++;
+  EXPECT_EQ(iter, foomap.cend());
+
+  int count = 0;
+  for (auto it = foomap.cbegin(); it != foomap.cend(); it++) {
+    count++;
+  }
+  EXPECT_EQ(count, 2);
+}
+
+// TODO: hazptrs must support DeterministicSchedule
+
+#define Atom std::atomic // DeterministicAtomic
+#define Mutex std::mutex // DeterministicMutex
+#define lib std // DeterministicSchedule
+#define join t.join() // DeterministicSchedule::join(t)
+// #define Atom DeterministicAtomic
+// #define Mutex DeterministicMutex
+// #define lib DeterministicSchedule
+// #define join DeterministicSchedule::join(t)
+
+TEST(ConcurrentHashMap, UpdateStressTest) {
+  DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+  // size must match iters for this test.
+  unsigned size = 128 * 128;
+  unsigned iters = size;
+  ConcurrentHashMap<
+      unsigned long,
+      unsigned long,
+      std::hash<unsigned long>,
+      std::equal_to<unsigned long>,
+      std::allocator<uint8_t>,
+      8,
+      Atom,
+      Mutex>
+      m(2);
+
+  for (uint i = 0; i < size; i++) {
+    m.insert(i, i);
+  }
+  std::vector<std::thread> threads;
+  unsigned int num_threads = 32;
+  for (uint t = 0; t < num_threads; t++) {
+    threads.push_back(lib::thread([&, t]() {
+      int offset = (iters * t / num_threads);
+      for (uint i = 0; i < iters / num_threads; i++) {
+        unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+        k = k % (iters / num_threads) + offset;
+        unsigned long val = 3;
+        auto res = m.find(k);
+        EXPECT_NE(res, m.cend());
+        EXPECT_EQ(k, res->second);
+        auto r = m.assign(k, res->second);
+        EXPECT_TRUE(r);
+        res = m.find(k);
+        EXPECT_NE(res, m.cend());
+        EXPECT_EQ(k, res->second);
+        // Another random insertion to force table resizes
+        val = size + i + offset;
+        EXPECT_TRUE(m.insert(val, val).second);
+      }
+    }));
+  }
+  for (auto& t : threads) {
+    join;
+  }
+}
+
+TEST(ConcurrentHashMap, EraseStressTest) {
+  DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+  unsigned size = 2;
+  unsigned iters = size * 128 * 2;
+  ConcurrentHashMap<
+      unsigned long,
+      unsigned long,
+      std::hash<unsigned long>,
+      std::equal_to<unsigned long>,
+      std::allocator<uint8_t>,
+      8,
+      Atom,
+      Mutex>
+      m(2);
+
+  for (uint i = 0; i < size; i++) {
+    unsigned long k = folly::hash::jenkins_rev_mix32(i);
+    m.insert(k, k);
+  }
+  std::vector<std::thread> threads;
+  unsigned int num_threads = 32;
+  for (uint t = 0; t < num_threads; t++) {
+    threads.push_back(lib::thread([&, t]() {
+      int offset = (iters * t / num_threads);
+      for (uint i = 0; i < iters / num_threads; i++) {
+        unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+        unsigned long val;
+        auto res = m.insert(k, k).second;
+        if (res) {
+          res = m.erase(k);
+          if (!res) {
+            printf("Faulre to erase thread %i val %li\n", t, k);
+            exit(0);
+          }
+          EXPECT_TRUE(res);
+        }
+        res = m.insert(k, k).second;
+        if (res) {
+          res = bool(m.assign(k, k));
+          if (!res) {
+            printf("Thread %i update fail %li res%i\n", t, k, res);
+            exit(0);
+          }
+          EXPECT_TRUE(res);
+          auto res = m.find(k);
+          if (res == m.cend()) {
+            printf("Thread %i lookup fail %li\n", t, k);
+            exit(0);
+          }
+          EXPECT_EQ(k, res->second);
+        }
+      }
+    }));
+  }
+  for (auto& t : threads) {
+    join;
+  }
+}
+
+TEST(ConcurrentHashMap, IterateStressTest) {
+  DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+  unsigned size = 2;
+  unsigned iters = size * 128 * 2;
+  ConcurrentHashMap<
+      unsigned long,
+      unsigned long,
+      std::hash<unsigned long>,
+      std::equal_to<unsigned long>,
+      std::allocator<uint8_t>,
+      8,
+      Atom,
+      Mutex>
+      m(2);
+
+  for (uint i = 0; i < size; i++) {
+    unsigned long k = folly::hash::jenkins_rev_mix32(i);
+    m.insert(k, k);
+  }
+  for (uint i = 0; i < 10; i++) {
+    m.insert(i, i);
+  }
+  std::vector<std::thread> threads;
+  unsigned int num_threads = 32;
+  for (uint t = 0; t < num_threads; t++) {
+    threads.push_back(lib::thread([&, t]() {
+      int offset = (iters * t / num_threads);
+      for (uint i = 0; i < iters / num_threads; i++) {
+        unsigned long k = folly::hash::jenkins_rev_mix32((i + offset));
+        unsigned long val;
+        auto res = m.insert(k, k).second;
+        if (res) {
+          res = m.erase(k);
+          if (!res) {
+            printf("Faulre to erase thread %i val %li\n", t, k);
+            exit(0);
+          }
+          EXPECT_TRUE(res);
+        }
+        int count = 0;
+        for (auto it = m.cbegin(); it != m.cend(); it++) {
+          printf("Item is %li\n", it->first);
+          if (it->first < 10) {
+            count++;
+          }
+        }
+        EXPECT_EQ(count, 10);
+      }
+    }));
+  }
+  for (auto& t : threads) {
+    join;
+  }
+}
+
+TEST(ConcurrentHashMap, insertStressTest) {
+  DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+  unsigned size = 2;
+  unsigned iters = size * 64 * 4;
+  ConcurrentHashMap<
+      unsigned long,
+      unsigned long,
+      std::hash<unsigned long>,
+      std::equal_to<unsigned long>,
+      std::allocator<uint8_t>,
+      8,
+      Atom,
+      Mutex>
+      m(2);
+
+  EXPECT_TRUE(m.insert(0, 0).second);
+  EXPECT_FALSE(m.insert(0, 0).second);
+  std::vector<std::thread> threads;
+  unsigned int num_threads = 32;
+  for (uint t = 0; t < num_threads; t++) {
+    threads.push_back(lib::thread([&, t]() {
+      int offset = (iters * t / num_threads);
+      for (uint i = 0; i < iters / num_threads; i++) {
+        auto var = offset + i + 1;
+        EXPECT_TRUE(m.insert(var, var).second);
+        EXPECT_FALSE(m.insert(0, 0).second);
+      }
+    }));
+  }
+  for (auto& t : threads) {
+    join;
+  }
+}
+
+TEST(ConcurrentHashMap, assignStressTest) {
+  DeterministicSchedule sched(DeterministicSchedule::uniform(FLAGS_seed));
+
+  unsigned size = 2;
+  unsigned iters = size * 64 * 4;
+  struct big_value {
+    uint64_t v1;
+    uint64_t v2;
+    uint64_t v3;
+    uint64_t v4;
+    uint64_t v5;
+    uint64_t v6;
+    uint64_t v7;
+    uint64_t v8;
+    void set(uint64_t v) {
+      v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = v;
+    }
+    void check() const {
+      auto v = v1;
+      EXPECT_EQ(v, v8);
+      EXPECT_EQ(v, v7);
+      EXPECT_EQ(v, v6);
+      EXPECT_EQ(v, v5);
+      EXPECT_EQ(v, v4);
+      EXPECT_EQ(v, v3);
+      EXPECT_EQ(v, v2);
+    }
+  };
+  ConcurrentHashMap<
+      unsigned long,
+      big_value,
+      std::hash<unsigned long>,
+      std::equal_to<unsigned long>,
+      std::allocator<uint8_t>,
+      8,
+      Atom,
+      Mutex>
+      m(2);
+
+  for (uint i = 0; i < iters; i++) {
+    big_value a;
+    a.set(i);
+    m.insert(i, a);
+  }
+
+  std::vector<std::thread> threads;
+  unsigned int num_threads = 32;
+  for (uint t = 0; t < num_threads; t++) {
+    threads.push_back(lib::thread([&]() {
+      for (uint i = 0; i < iters; i++) {
+        auto res = m.find(i);
+        EXPECT_NE(res, m.cend());
+        res->second.check();
+        big_value b;
+        b.set(res->second.v1 + 1);
+        m.assign(i, b);
+      }
+    }));
+  }
+  for (auto& t : threads) {
+    join;
+  }
+}
index 60cf5ecd6ec5cbfeccd17df32a94aa8ddc294711..4dd0de889c3113cce692721edb8fdd7b5f2b1beb 100644 (file)
@@ -611,7 +611,7 @@ inline bool hazptr_tc::put(hazptr_rec* hprec) {
   return false;
 }
 
-inline class hazptr_tc& hazptr_tc() {
+FOLLY_ALWAYS_INLINE class hazptr_tc& hazptr_tc() {
   static thread_local class hazptr_tc tc;
   DEBUG_PRINT(&tc);
   return tc;