Domain destruction fixes
[folly.git] / folly / experimental / hazptr / hazptr-impl.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* override-include-guard */
18 #ifndef HAZPTR_H
19 #error "This should only be included by hazptr.h"
20 #endif
21
22 #include <folly/experimental/hazptr/debug.h>
23
24 #include <unordered_set>
25
26 namespace folly {
27 namespace hazptr {
28
29 /** hazptr_domain */
30
31 constexpr hazptr_domain::hazptr_domain(memory_resource* mr) noexcept
32     : mr_(mr) {}
33
34 /** hazptr_obj_base */
35
36 template <typename T, typename D>
37 inline void hazptr_obj_base<T, D>::retire(hazptr_domain& domain, D deleter) {
38   DEBUG_PRINT(this << " " << &domain);
39   deleter_ = std::move(deleter);
40   reclaim_ = [](hazptr_obj* p) {
41     auto hobp = static_cast<hazptr_obj_base*>(p);
42     auto obj = static_cast<T*>(hobp);
43     hobp->deleter_(obj);
44   };
45   domain.objRetire(this);
46 }
47
48 /** hazptr_rec */
49
50 class hazptr_rec {
51   friend class hazptr_domain;
52   template <typename> friend class hazptr_owner;
53
54   std::atomic<const void*> hazptr_ = {nullptr};
55   hazptr_rec* next_ = {nullptr};
56   std::atomic<bool> active_ = {false};
57
58   void set(const void* p) noexcept;
59   const void* get() const noexcept;
60   void clear() noexcept;
61   void release() noexcept;
62 };
63
64 /** hazptr_owner */
65
66 template <typename T>
67 inline hazptr_owner<T>::hazptr_owner(hazptr_domain& domain) {
68   domain_ = &domain;
69   hazptr_ = domain_->hazptrAcquire();
70   DEBUG_PRINT(this << " " << domain_ << " " << hazptr_);
71   if (hazptr_ == nullptr) { std::bad_alloc e; throw e; }
72 }
73
74 template <typename T>
75 hazptr_owner<T>::~hazptr_owner() {
76   DEBUG_PRINT(this);
77   domain_->hazptrRelease(hazptr_);
78 }
79
80 template <typename T>
81 template <typename A>
82 inline bool hazptr_owner<T>::try_protect(T*& ptr, const A& src) noexcept {
83   static_assert(
84       std::is_same<decltype(std::declval<A>().load()), T*>::value,
85       "Return type of A::load() must be T*");
86   DEBUG_PRINT(this << " " << ptr << " " << &src);
87   set(ptr);
88   T* p = src.load();
89   if (p != ptr) {
90     ptr = p;
91     clear();
92     return false;
93   }
94   return true;
95 }
96
97 template <typename T>
98 template <typename A>
99 inline T* hazptr_owner<T>::get_protected(const A& src) noexcept {
100   static_assert(
101       std::is_same<decltype(std::declval<A>().load()), T*>::value,
102       "Return type of A::load() must be T*");
103   T* p = src.load();
104   while (!try_protect(p, src)) {}
105   DEBUG_PRINT(this << " " << p << " " << &src);
106   return p;
107 }
108
109 template <typename T>
110 inline void hazptr_owner<T>::set(const T* ptr) noexcept {
111   auto p = static_cast<hazptr_obj*>(const_cast<T*>(ptr));
112   DEBUG_PRINT(this << " " << ptr << " p:" << p);
113   hazptr_->set(p);
114 }
115
116 template <typename T>
117 inline void hazptr_owner<T>::clear() noexcept {
118   DEBUG_PRINT(this);
119   hazptr_->clear();
120 }
121
122 template <typename T>
123 inline void hazptr_owner<T>::swap(hazptr_owner<T>& rhs) noexcept {
124   DEBUG_PRINT(
125     this << " " <<  this->hazptr_ << " " << this->domain_ << " -- "
126     << &rhs << " " << rhs.hazptr_ << " " << rhs.domain_);
127   std::swap(this->domain_, rhs.domain_);
128   std::swap(this->hazptr_, rhs.hazptr_);
129 }
130
131 template <typename T>
132 inline void swap(hazptr_owner<T>& lhs, hazptr_owner<T>& rhs) noexcept {
133   lhs.swap(rhs);
134 }
135
136 ////////////////////////////////////////////////////////////////////////////////
137 // Non-template part of implementation
138 ////////////////////////////////////////////////////////////////////////////////
139 // [TODO]:
140 // - Thread caching of hazptr_rec-s
141 // - Private storage of retired objects
142 // - Control of reclamation (when and by whom)
143 // - Optimized memory order
144
145 /** Definition of default_hazptr_domain() */
146 inline hazptr_domain& default_hazptr_domain() {
147   static hazptr_domain d;
148   DEBUG_PRINT(&d);
149   return d;
150 }
151
152 /** hazptr_rec */
153
154 inline void hazptr_rec::set(const void* p) noexcept {
155   DEBUG_PRINT(this << " " << p);
156   hazptr_.store(p);
157 }
158
159 inline const void* hazptr_rec::get() const noexcept {
160   DEBUG_PRINT(this << " " << hazptr_.load());
161   return hazptr_.load();
162 }
163
164 inline void hazptr_rec::clear() noexcept {
165   DEBUG_PRINT(this);
166   hazptr_.store(nullptr);
167 }
168
169 inline void hazptr_rec::release() noexcept {
170   DEBUG_PRINT(this);
171   clear();
172   active_.store(false);
173 }
174
175 /** hazptr_obj */
176
177 inline const void* hazptr_obj::getObjPtr() const {
178   DEBUG_PRINT(this);
179   return this;
180 }
181
182 /** hazptr_domain */
183
184 inline hazptr_domain::~hazptr_domain() {
185   DEBUG_PRINT(this);
186   { /* reclaim all remaining retired objects */
187     hazptr_obj* next;
188     auto retired = retired_.exchange(nullptr);
189     while (retired) {
190       for (auto p = retired; p; p = next) {
191         next = p->next_;
192         (*(p->reclaim_))(p);
193       }
194       retired = retired_.exchange(nullptr);
195     }
196   }
197   { /* free all hazptr_rec-s */
198     hazptr_rec* next;
199     for (auto p = hazptrs_.load(); p; p = next) {
200       next = p->next_;
201       mr_->deallocate(static_cast<void*>(p), sizeof(hazptr_rec));
202     }
203   }
204 }
205
206 inline void hazptr_domain::try_reclaim() {
207   DEBUG_PRINT(this);
208   rcount_.exchange(0);
209   bulkReclaim();
210 }
211
212 inline hazptr_rec* hazptr_domain::hazptrAcquire() {
213   hazptr_rec* p;
214   hazptr_rec* next;
215   for (p = hazptrs_.load(); p; p = next) {
216     next = p->next_;
217     bool active = p->active_.load();
218     if (!active) {
219       if (p->active_.compare_exchange_weak(active, true)) {
220         DEBUG_PRINT(this << " " << p);
221         return p;
222       }
223     }
224   }
225   p = static_cast<hazptr_rec*>(mr_->allocate(sizeof(hazptr_rec)));
226   if (p == nullptr) {
227     return nullptr;
228   }
229   p->active_.store(true);
230   do {
231     p->next_ = hazptrs_.load();
232     if (hazptrs_.compare_exchange_weak(p->next_, p)) {
233       break;
234     }
235   } while (true);
236   auto hcount = hcount_.fetch_add(1);
237   DEBUG_PRINT(this << " " << p << " " << sizeof(hazptr_rec) << " " << hcount);
238   return p;
239 }
240
241 inline void hazptr_domain::hazptrRelease(hazptr_rec* p) noexcept {
242   DEBUG_PRINT(this << " " << p);
243   p->release();
244 }
245
246 inline int
247 hazptr_domain::pushRetired(hazptr_obj* head, hazptr_obj* tail, int count) {
248   tail->next_ = retired_.load();
249   while (!retired_.compare_exchange_weak(tail->next_, head)) {}
250   return rcount_.fetch_add(count);
251 }
252
253 inline void hazptr_domain::objRetire(hazptr_obj* p) {
254   auto rcount = pushRetired(p, p, 1) + 1;
255   if (rcount >= kScanThreshold * hcount_.load()) {
256     tryBulkReclaim();
257   }
258 }
259
260 inline void hazptr_domain::tryBulkReclaim() {
261   DEBUG_PRINT(this);
262   do {
263     auto hcount = hcount_.load();
264     auto rcount = rcount_.load();
265     if (rcount < kScanThreshold * hcount) {
266       return;
267     }
268     if (rcount_.compare_exchange_weak(rcount, 0)) {
269       break;
270     }
271   } while (true);
272   bulkReclaim();
273 }
274
275 inline void hazptr_domain::bulkReclaim() {
276   DEBUG_PRINT(this);
277   auto p = retired_.exchange(nullptr);
278   auto h = hazptrs_.load();
279   std::unordered_set<const void*> hs;
280   for (; h; h = h->next_) {
281     hs.insert(h->hazptr_.load());
282   }
283   int rcount = 0;
284   hazptr_obj* retired = nullptr;
285   hazptr_obj* tail = nullptr;
286   hazptr_obj* next;
287   for (; p; p = next) {
288     next = p->next_;
289     if (hs.count(p->getObjPtr()) == 0) {
290       DEBUG_PRINT(this << " " << p << " " << p->reclaim_);
291       (*(p->reclaim_))(p);
292     } else {
293       p->next_ = retired;
294       retired = p;
295       if (tail == nullptr) {
296         tail = p;
297       }
298       ++rcount;
299     }
300   }
301   if (tail) {
302     pushRetired(retired, tail, rcount);
303   }
304 }
305
306 } // namespace folly
307 } // namespace hazptr