ReadMostlyMainPtrDeleter
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/ThreadLocal.h>
19 #include <folly/experimental/AsymmetricMemoryBarrier.h>
20
21 namespace folly {
22
23 class TLRefCount {
24  public:
25   using Int = int64_t;
26
27   TLRefCount()
28       : localCount_([&]() { return new LocalRefCount(*this); }),
29         collectGuard_(this, [](void*) {}) {}
30
31   ~TLRefCount() noexcept {
32     assert(globalCount_.load() == 0);
33     assert(state_.load() == State::GLOBAL);
34   }
35
36   // This can't increment from 0.
37   Int operator++() noexcept {
38     auto& localCount = *localCount_;
39
40     if (++localCount) {
41       return 42;
42     }
43
44     if (state_.load() == State::GLOBAL_TRANSITION) {
45       std::lock_guard<std::mutex> lg(globalMutex_);
46     }
47
48     assert(state_.load() == State::GLOBAL);
49
50     auto value = globalCount_.load();
51     do {
52       if (value == 0) {
53         return 0;
54       }
55     } while (!globalCount_.compare_exchange_weak(value, value+1));
56
57     return value + 1;
58   }
59
60   Int operator--() noexcept {
61     auto& localCount = *localCount_;
62
63     if (--localCount) {
64       return 42;
65     }
66
67     if (state_.load() == State::GLOBAL_TRANSITION) {
68       std::lock_guard<std::mutex> lg(globalMutex_);
69     }
70
71     assert(state_.load() == State::GLOBAL);
72
73     return globalCount_-- - 1;
74   }
75
76   Int operator*() const {
77     if (state_ != State::GLOBAL) {
78       return 42;
79     }
80     return globalCount_.load();
81   }
82
83   void useGlobal() noexcept {
84     std::array<TLRefCount*, 1> ptrs{{this}};
85     useGlobal(ptrs);
86   }
87
88   template <typename Container>
89   static void useGlobal(const Container& refCountPtrs) {
90     std::vector<std::unique_lock<std::mutex>> lgs_;
91     for (auto refCountPtr : refCountPtrs) {
92       lgs_.emplace_back(refCountPtr->globalMutex_);
93
94       refCountPtr->state_ = State::GLOBAL_TRANSITION;
95     }
96
97     asymmetricHeavyBarrier();
98
99     for (auto refCountPtr : refCountPtrs) {
100       std::weak_ptr<void> collectGuardWeak = refCountPtr->collectGuard_;
101
102       // Make sure we can't create new LocalRefCounts
103       refCountPtr->collectGuard_.reset();
104
105       while (!collectGuardWeak.expired()) {
106         auto accessor = refCountPtr->localCount_.accessAllThreads();
107         for (auto& count : accessor) {
108           count.collect();
109         }
110       }
111
112       refCountPtr->state_ = State::GLOBAL;
113     }
114   }
115
116  private:
117   using AtomicInt = std::atomic<Int>;
118
119   enum class State {
120     LOCAL,
121     GLOBAL_TRANSITION,
122     GLOBAL
123   };
124
125   class LocalRefCount {
126    public:
127     explicit LocalRefCount(TLRefCount& refCount) :
128         refCount_(refCount) {
129       std::lock_guard<std::mutex> lg(refCount.globalMutex_);
130
131       collectGuard_ = refCount.collectGuard_;
132     }
133
134     ~LocalRefCount() {
135       collect();
136     }
137
138     void collect() {
139       std::lock_guard<std::mutex> lg(collectMutex_);
140
141       if (!collectGuard_) {
142         return;
143       }
144
145       collectCount_ = count_.load();
146       refCount_.globalCount_.fetch_add(collectCount_);
147       collectGuard_.reset();
148     }
149
150     bool operator++() {
151       return update(1);
152     }
153
154     bool operator--() {
155       return update(-1);
156     }
157
158    private:
159     bool update(Int delta) {
160       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
161         return false;
162       }
163
164       // This is equivalent to atomic fetch_add. We know that this operation
165       // is always performed from a single thread. asymmetricLightBarrier()
166       // makes things faster than atomic fetch_add on platforms with native
167       // support.
168       auto count = count_.load(std::memory_order_relaxed) + delta;
169       count_.store(count, std::memory_order_relaxed);
170
171       asymmetricLightBarrier();
172
173       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
174         std::lock_guard<std::mutex> lg(collectMutex_);
175
176         if (collectGuard_) {
177           return true;
178         }
179         if (collectCount_ != count) {
180           return false;
181         }
182       }
183
184       return true;
185     }
186
187     AtomicInt count_{0};
188     TLRefCount& refCount_;
189
190     std::mutex collectMutex_;
191     Int collectCount_{0};
192     std::shared_ptr<void> collectGuard_;
193   };
194
195   std::atomic<State> state_{State::LOCAL};
196   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
197   std::atomic<int64_t> globalCount_{1};
198   std::mutex globalMutex_;
199   std::shared_ptr<void> collectGuard_;
200 };
201
202 }