Heterogeneous lookups for sorted_vector types
authorYedidya Feldblum <yfeldblum@fb.com>
Wed, 8 Nov 2017 17:22:17 +0000 (09:22 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Nov 2017 17:49:39 +0000 (09:49 -0800)
Summary:
[Folly] Heterogeneous lookups for `sorted_vector` types.

When the `Compare` type has member type or alias `is_transparent`, enable template overloads of `count`, `find`, `lower_bound`, `upper_bound`, and `equal_range` on both `sorted_vector_set` and `sorted_vector_map`.

This is the protocol found in the equivalent `std::set` and `std::map` member functions.

> This overload only participates in overload resolution if the qualified-id `Compare::is_transparent` is valid and denotes a type. They allow calling this function without constructing an instance of `Key`.
>
> http://en.cppreference.com/w/cpp/container/set/count (same wording in all 10 cases)

Reviewed By: nbronson

Differential Revision: D6256989

fbshipit-source-id: a40a181453a019564e8f7674e1e07e241d5ab068

folly/sorted_vector_types.h
folly/test/sorted_vector_test.cpp

index d011a83d401276e16741fe2d4209183ad3d51908..2cff300f9ec3d8a05a56cafe407d27384b33c053 100644 (file)
@@ -68,6 +68,8 @@
 #include <vector>
 
 #include <boost/operators.hpp>
+
+#include <folly/Traits.h>
 #include <folly/portability/BitsFunctexcept.h>
 
 namespace folly {
@@ -76,6 +78,18 @@ namespace folly {
 
 namespace detail {
 
+template <typename, typename Compare, typename Key, typename T>
+struct sorted_vector_enable_if_is_transparent {};
+
+template <typename Compare, typename Key, typename T>
+struct sorted_vector_enable_if_is_transparent<
+    void_t<typename Compare::is_transparent>,
+    Compare,
+    Key,
+    T> {
+  using type = T;
+};
+
 // This wrapper goes around a GrowthPolicy and provides iterator
 // preservation semantics, but only if the growth policy is not the
 // default (i.e. nothing).
@@ -212,6 +226,10 @@ class sorted_vector_set
   detail::growth_policy_wrapper<GrowthPolicy>&
   get_growth_policy() { return *this; }
 
+  template <typename K, typename V, typename C = Compare>
+  using if_is_transparent =
+      _t<detail::sorted_vector_enable_if_is_transparent<void, C, K, V>>;
+
  public:
   typedef T       value_type;
   typedef T       key_type;
@@ -343,25 +361,32 @@ class sorted_vector_set
   }
 
   iterator find(const key_type& key) {
-    iterator it = lower_bound(key);
-    if (it == end() || !key_comp()(key, *it)) {
-      return it;
-    }
-    return end();
+    return find(*this, key);
   }
 
   const_iterator find(const key_type& key) const {
-    const_iterator it = lower_bound(key);
-    if (it == end() || !key_comp()(key, *it)) {
-      return it;
-    }
-    return end();
+    return find(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, iterator> find(const K& key) {
+    return find(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> find(const K& key) const {
+    return find(*this, key);
   }
 
   size_type count(const key_type& key) const {
     return find(key) == end() ? 0 : 1;
   }
 
+  template <typename K>
+  if_is_transparent<K, size_type> count(const K& key) const {
+    return find(key) == end() ? 0 : 1;
+  }
+
   iterator lower_bound(const key_type& key) {
     return std::lower_bound(begin(), end(), key, key_comp());
   }
@@ -370,6 +395,16 @@ class sorted_vector_set
     return std::lower_bound(begin(), end(), key, key_comp());
   }
 
+  template <typename K>
+  if_is_transparent<K, iterator> lower_bound(const K& key) {
+    return std::lower_bound(begin(), end(), key, key_comp());
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> lower_bound(const K& key) const {
+    return std::lower_bound(begin(), end(), key, key_comp());
+  }
+
   iterator upper_bound(const key_type& key) {
     return std::upper_bound(begin(), end(), key, key_comp());
   }
@@ -378,12 +413,34 @@ class sorted_vector_set
     return std::upper_bound(begin(), end(), key, key_comp());
   }
 
-  std::pair<iterator,iterator> equal_range(const key_type& key) {
+  template <typename K>
+  if_is_transparent<K, iterator> upper_bound(const K& key) {
+    return std::upper_bound(begin(), end(), key, key_comp());
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> upper_bound(const K& key) const {
+    return std::upper_bound(begin(), end(), key, key_comp());
+  }
+
+  std::pair<iterator, iterator> equal_range(const key_type& key) {
+    return std::equal_range(begin(), end(), key, key_comp());
+  }
+
+  std::pair<const_iterator, const_iterator> equal_range(
+      const key_type& key) const {
+    return std::equal_range(begin(), end(), key, key_comp());
+  }
+
+  template <typename K>
+  if_is_transparent<K, std::pair<iterator, iterator>> equal_range(
+      const K& key) {
     return std::equal_range(begin(), end(), key, key_comp());
   }
 
-  std::pair<const_iterator,const_iterator>
-  equal_range(const key_type& key) const {
+  template <typename K>
+  if_is_transparent<K, std::pair<const_iterator, const_iterator>> equal_range(
+      const K& key) const {
     return std::equal_range(begin(), end(), key, key_comp());
   }
 
@@ -423,6 +480,20 @@ class sorted_vector_set
     {}
     ContainerT cont_;
   } m_;
+
+  template <typename Self>
+  using self_iterator_t = _t<
+      std::conditional<std::is_const<Self>::value, const_iterator, iterator>>;
+
+  template <typename Self, typename K>
+  static self_iterator_t<Self> find(Self& self, K const& key) {
+    auto end = self.end();
+    auto it = self.lower_bound(key);
+    if (it == end || !self.key_comp()(key, *it)) {
+      return it;
+    }
+    return end;
+  }
 };
 
 // Swap function that can be found using ADL.
@@ -465,6 +536,10 @@ class sorted_vector_map
   detail::growth_policy_wrapper<GrowthPolicy>&
   get_growth_policy() { return *this; }
 
+  template <typename K, typename V, typename C = Compare>
+  using if_is_transparent =
+      _t<detail::sorted_vector_enable_if_is_transparent<void, C, K, V>>;
+
  public:
   typedef Key                                       key_type;
   typedef Value                                     mapped_type;
@@ -599,19 +674,21 @@ class sorted_vector_map
   }
 
   iterator find(const key_type& key) {
-    iterator it = lower_bound(key);
-    if (it == end() || !key_comp()(key, it->first)) {
-      return it;
-    }
-    return end();
+    return find(*this, key);
   }
 
   const_iterator find(const key_type& key) const {
-    const_iterator it = lower_bound(key);
-    if (it == end() || !key_comp()(key, it->first)) {
-      return it;
-    }
-    return end();
+    return find(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, iterator> find(const K& key) {
+    return find(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> find(const K& key) const {
+    return find(*this, key);
   }
 
   mapped_type& at(const key_type& key) {
@@ -634,54 +711,66 @@ class sorted_vector_map
     return find(key) == end() ? 0 : 1;
   }
 
+  template <typename K>
+  if_is_transparent<K, size_type> count(const K& key) const {
+    return find(key) == end() ? 0 : 1;
+  }
+
   iterator lower_bound(const key_type& key) {
-    auto c = key_comp();
-    auto f = [&](const value_type& a, const key_type& b) {
-      return c(a.first, b);
-    };
-    return std::lower_bound(begin(), end(), key, f);
+    return lower_bound(*this, key);
   }
 
   const_iterator lower_bound(const key_type& key) const {
-    auto c = key_comp();
-    auto f = [&](const value_type& a, const key_type& b) {
-      return c(a.first, b);
-    };
-    return std::lower_bound(begin(), end(), key, f);
+    return lower_bound(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, iterator> lower_bound(const K& key) {
+    return lower_bound(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> lower_bound(const K& key) const {
+    return lower_bound(*this, key);
   }
 
   iterator upper_bound(const key_type& key) {
-    auto c = key_comp();
-    auto f = [&](const key_type& a, const value_type& b) {
-      return c(a, b.first);
-    };
-    return std::upper_bound(begin(), end(), key, f);
+    return upper_bound(*this, key);
   }
 
   const_iterator upper_bound(const key_type& key) const {
-    auto c = key_comp();
-    auto f = [&](const key_type& a, const value_type& b) {
-      return c(a, b.first);
-    };
-    return std::upper_bound(begin(), end(), key, f);
+    return upper_bound(*this, key);
   }
 
-  std::pair<iterator,iterator> equal_range(const key_type& key) {
-    // Note: std::equal_range can't be passed a functor that takes
-    // argument types different from the iterator value_type, so we
-    // have to do this.
-    iterator low = lower_bound(key);
-    auto c = key_comp();
-    auto f = [&](const key_type& a, const value_type& b) {
-      return c(a, b.first);
-    };
-    iterator high = std::upper_bound(low, end(), key, f);
-    return std::make_pair(low, high);
+  template <typename K>
+  if_is_transparent<K, iterator> upper_bound(const K& key) {
+    return upper_bound(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, const_iterator> upper_bound(const K& key) const {
+    return upper_bound(*this, key);
   }
 
-  std::pair<const_iterator,const_iterator>
-  equal_range(const key_type& key) const {
-    return const_cast<sorted_vector_map*>(this)->equal_range(key);
+  std::pair<iterator, iterator> equal_range(const key_type& key) {
+    return equal_range(*this, key);
+  }
+
+  std::pair<const_iterator, const_iterator> equal_range(
+      const key_type& key) const {
+    return equal_range(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, std::pair<iterator, iterator>> equal_range(
+      const K& key) {
+    return equal_range(*this, key);
+  }
+
+  template <typename K>
+  if_is_transparent<K, std::pair<const_iterator, const_iterator>> equal_range(
+      const K& key) const {
+    return equal_range(*this, key);
   }
 
   // Nothrow as long as swap() on the Compare type is nothrow.
@@ -719,6 +808,46 @@ class sorted_vector_map
     {}
     ContainerT cont_;
   } m_;
+
+  template <typename Self>
+  using self_iterator_t = _t<
+      std::conditional<std::is_const<Self>::value, const_iterator, iterator>>;
+
+  template <typename Self, typename K>
+  static self_iterator_t<Self> find(Self& self, K const& key) {
+    auto end = self.end();
+    auto it = self.lower_bound(key);
+    if (it == end || !self.key_comp()(key, it->first)) {
+      return it;
+    }
+    return end;
+  }
+
+  template <typename Self, typename K>
+  static self_iterator_t<Self> lower_bound(Self& self, K const& key) {
+    auto f = [c = self.key_comp()](value_type const& a, K const& b) {
+      return c(a.first, b);
+    };
+    return std::lower_bound(self.begin(), self.end(), key, f);
+  }
+
+  template <typename Self, typename K>
+  static self_iterator_t<Self> upper_bound(Self& self, K const& key) {
+    auto f = [c = self.key_comp()](K const& a, value_type const& b) {
+      return c(a, b.first);
+    };
+    return std::upper_bound(self.begin(), self.end(), key, f);
+  }
+
+  template <typename Self, typename K>
+  static std::pair<self_iterator_t<Self>, self_iterator_t<Self>> equal_range(
+      Self& self,
+      K const& key) {
+    // Note: std::equal_range can't be passed a functor that takes
+    // argument types different from the iterator value_type, so we
+    // have to do this.
+    return {lower_bound(self, key), upper_bound(self, key)};
+  }
 };
 
 // Swap function that can be found using ADL.
index fa9c977e09aa4a003f4064d4dfccad8eac0ef3f1..7efa4c8f0cfd351060f257cb7fcf7faddda74f34 100644 (file)
@@ -76,6 +76,27 @@ struct CountCopyCtor {
   int count_;
 };
 
+struct Opaque {
+  int value;
+  friend bool operator==(Opaque a, Opaque b) {
+    return a.value == b.value;
+  }
+  friend bool operator<(Opaque a, Opaque b) {
+    return a.value < b.value;
+  }
+  struct Compare : std::less<int>, std::less<Opaque> {
+    using is_transparent = void;
+    using std::less<int>::operator();
+    using std::less<Opaque>::operator();
+    bool operator()(int a, Opaque b) const {
+      return std::less<int>::operator()(a, b.value);
+    }
+    bool operator()(Opaque a, int b) const {
+      return std::less<int>::operator()(a.value, b);
+    }
+  };
+};
+
 } // namespace
 
 TEST(SortedVectorTypes, SimpleSetTest) {
@@ -145,6 +166,73 @@ TEST(SortedVectorTypes, SimpleSetTest) {
   EXPECT_TRUE(cpy2 == cpy);
 }
 
+TEST(SortedVectorTypes, TransparentSetTest) {
+  sorted_vector_set<Opaque, Opaque::Compare> s;
+  EXPECT_TRUE(s.empty());
+  for (int i = 0; i < 1000; ++i) {
+    s.insert(Opaque{rand() % 100000});
+  }
+  EXPECT_FALSE(s.empty());
+  check_invariant(s);
+
+  sorted_vector_set<Opaque, Opaque::Compare> s2;
+  s2.insert(s.begin(), s.end());
+  check_invariant(s2);
+  EXPECT_TRUE(s == s2);
+
+  auto it = s2.lower_bound(32);
+  if (it->value == 32) {
+    s2.erase(it);
+    it = s2.lower_bound(32);
+  }
+  check_invariant(s2);
+  auto oldSz = s2.size();
+  s2.insert(it, Opaque{32});
+  EXPECT_TRUE(s2.size() == oldSz + 1);
+  check_invariant(s2);
+
+  const sorted_vector_set<Opaque, Opaque::Compare>& cs2 = s2;
+  auto range = cs2.equal_range(32);
+  auto lbound = cs2.lower_bound(32);
+  auto ubound = cs2.upper_bound(32);
+  EXPECT_TRUE(range.first == lbound);
+  EXPECT_TRUE(range.second == ubound);
+  EXPECT_TRUE(range.first != cs2.end());
+  EXPECT_TRUE(range.second != cs2.end());
+  EXPECT_TRUE(cs2.count(32) == 1);
+  EXPECT_FALSE(cs2.find(32) == cs2.end());
+
+  // Bad insert hint.
+  s2.insert(s2.begin() + 3, Opaque{33});
+  EXPECT_TRUE(s2.find(33) != s2.begin());
+  EXPECT_TRUE(s2.find(33) != s2.end());
+  check_invariant(s2);
+  s2.erase(Opaque{33});
+  check_invariant(s2);
+
+  it = s2.find(32);
+  EXPECT_FALSE(it == s2.end());
+  s2.erase(it);
+  EXPECT_TRUE(s2.size() == oldSz);
+  check_invariant(s2);
+
+  sorted_vector_set<Opaque, Opaque::Compare> cpy(s);
+  check_invariant(cpy);
+  EXPECT_TRUE(cpy == s);
+  sorted_vector_set<Opaque, Opaque::Compare> cpy2(s);
+  cpy2.insert(Opaque{100001});
+  EXPECT_TRUE(cpy2 != cpy);
+  EXPECT_TRUE(cpy2 != s);
+  check_invariant(cpy2);
+  EXPECT_TRUE(cpy2.count(100001) == 1);
+  s.swap(cpy2);
+  check_invariant(cpy2);
+  check_invariant(s);
+  EXPECT_TRUE(s != cpy);
+  EXPECT_TRUE(s != cpy2);
+  EXPECT_TRUE(cpy2 == cpy);
+}
+
 TEST(SortedVectorTypes, BadHints) {
   for (int toInsert = -1; toInsert <= 7; ++toInsert) {
     for (int hintPos = 0; hintPos <= 4; ++hintPos) {
@@ -221,6 +309,67 @@ TEST(SortedVectorTypes, SimpleMapTest) {
   check_invariant(m);
 }
 
+TEST(SortedVectorTypes, TransparentMapTest) {
+  sorted_vector_map<Opaque, float, Opaque::Compare> m;
+  for (int i = 0; i < 1000; ++i) {
+    m[Opaque{i}] = i / 1000.0;
+  }
+  check_invariant(m);
+
+  m[Opaque{32}] = 100.0;
+  check_invariant(m);
+  EXPECT_TRUE(m.count(32) == 1);
+  EXPECT_DOUBLE_EQ(100.0, m.at(Opaque{32}));
+  EXPECT_FALSE(m.find(32) == m.end());
+  m.erase(Opaque{32});
+  EXPECT_TRUE(m.find(32) == m.end());
+  check_invariant(m);
+  EXPECT_THROW(m.at(Opaque{32}), std::out_of_range);
+
+  sorted_vector_map<Opaque, float, Opaque::Compare> m2 = m;
+  EXPECT_TRUE(m2 == m);
+  EXPECT_FALSE(m2 != m);
+  auto it = m2.lower_bound(1 << 20);
+  EXPECT_TRUE(it == m2.end());
+  m2.insert(it, std::make_pair(Opaque{1 << 20}, 10.0f));
+  check_invariant(m2);
+  EXPECT_TRUE(m2.count(1 << 20) == 1);
+  EXPECT_TRUE(m < m2);
+  EXPECT_TRUE(m <= m2);
+
+  const sorted_vector_map<Opaque, float, Opaque::Compare>& cm = m;
+  auto range = cm.equal_range(42);
+  auto lbound = cm.lower_bound(42);
+  auto ubound = cm.upper_bound(42);
+  EXPECT_TRUE(range.first == lbound);
+  EXPECT_TRUE(range.second == ubound);
+  EXPECT_FALSE(range.first == cm.end());
+  EXPECT_FALSE(range.second == cm.end());
+  m.erase(m.lower_bound(42));
+  check_invariant(m);
+
+  sorted_vector_map<Opaque, float, Opaque::Compare> m3;
+  m3.insert(m2.begin(), m2.end());
+  check_invariant(m3);
+  EXPECT_TRUE(m3 == m2);
+  EXPECT_FALSE(m3 == m);
+
+  EXPECT_TRUE(m != m2);
+  EXPECT_TRUE(m2 == m3);
+  EXPECT_TRUE(m3 != m);
+  m.swap(m3);
+  check_invariant(m);
+  check_invariant(m2);
+  check_invariant(m3);
+  EXPECT_TRUE(m3 != m2);
+  EXPECT_TRUE(m3 != m);
+  EXPECT_TRUE(m == m2);
+
+  // Bad insert hint.
+  m.insert(m.begin() + 3, std::make_pair(Opaque{1 << 15}, 1.0f));
+  check_invariant(m);
+}
+
 TEST(SortedVectorTypes, Sizes) {
   EXPECT_EQ(sizeof(sorted_vector_set<int>),
             sizeof(std::vector<int>));