allow command to accept "--" separator
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2015-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/ThreadLocal.h>
19 #include <folly/synchronization/AsymmetricMemoryBarrier.h>
20
21 namespace folly {
22
23 class TLRefCount {
24  public:
25   using Int = int64_t;
26
27   TLRefCount()
28       : localCount_([&]() { return new LocalRefCount(*this); }),
29         collectGuard_(this, [](void*) {}) {}
30
31   ~TLRefCount() noexcept {
32     assert(globalCount_.load() == 0);
33     assert(state_.load() == State::GLOBAL);
34   }
35
36   // This can't increment from 0.
37   Int operator++() noexcept {
38     auto& localCount = *localCount_;
39
40     if (++localCount) {
41       return 42;
42     }
43
44     if (state_.load() == State::GLOBAL_TRANSITION) {
45       std::lock_guard<std::mutex> lg(globalMutex_);
46     }
47
48     assert(state_.load() == State::GLOBAL);
49
50     auto value = globalCount_.load();
51     do {
52       if (value == 0) {
53         return 0;
54       }
55     } while (!globalCount_.compare_exchange_weak(value, value+1));
56
57     return value + 1;
58   }
59
60   Int operator--() noexcept {
61     auto& localCount = *localCount_;
62
63     if (--localCount) {
64       return 42;
65     }
66
67     if (state_.load() == State::GLOBAL_TRANSITION) {
68       std::lock_guard<std::mutex> lg(globalMutex_);
69     }
70
71     assert(state_.load() == State::GLOBAL);
72
73     return globalCount_-- - 1;
74   }
75
76   Int operator*() const {
77     if (state_ != State::GLOBAL) {
78       return 42;
79     }
80     return globalCount_.load();
81   }
82
83   void useGlobal() noexcept {
84     std::array<TLRefCount*, 1> ptrs{{this}};
85     useGlobal(ptrs);
86   }
87
88   template <typename Container>
89   static void useGlobal(const Container& refCountPtrs) {
90 #ifdef FOLLY_SANITIZE_THREAD
91     // TSAN has a limitation for the number of locks held concurrently, so it's
92     // safer to call useGlobal() serially.
93     if (refCountPtrs.size() > 1) {
94       for (auto refCountPtr : refCountPtrs) {
95         refCountPtr->useGlobal();
96       }
97       return;
98     }
99 #endif
100
101     std::vector<std::unique_lock<std::mutex>> lgs_;
102     for (auto refCountPtr : refCountPtrs) {
103       lgs_.emplace_back(refCountPtr->globalMutex_);
104
105       refCountPtr->state_ = State::GLOBAL_TRANSITION;
106     }
107
108     asymmetricHeavyBarrier();
109
110     for (auto refCountPtr : refCountPtrs) {
111       std::weak_ptr<void> collectGuardWeak = refCountPtr->collectGuard_;
112
113       // Make sure we can't create new LocalRefCounts
114       refCountPtr->collectGuard_.reset();
115
116       while (!collectGuardWeak.expired()) {
117         auto accessor = refCountPtr->localCount_.accessAllThreads();
118         for (auto& count : accessor) {
119           count.collect();
120         }
121       }
122
123       refCountPtr->state_ = State::GLOBAL;
124     }
125   }
126
127  private:
128   using AtomicInt = std::atomic<Int>;
129
130   enum class State {
131     LOCAL,
132     GLOBAL_TRANSITION,
133     GLOBAL
134   };
135
136   class LocalRefCount {
137    public:
138     explicit LocalRefCount(TLRefCount& refCount) :
139         refCount_(refCount) {
140       std::lock_guard<std::mutex> lg(refCount.globalMutex_);
141
142       collectGuard_ = refCount.collectGuard_;
143     }
144
145     ~LocalRefCount() {
146       collect();
147     }
148
149     void collect() {
150       std::lock_guard<std::mutex> lg(collectMutex_);
151
152       if (!collectGuard_) {
153         return;
154       }
155
156       collectCount_ = count_.load();
157       refCount_.globalCount_.fetch_add(collectCount_);
158       collectGuard_.reset();
159     }
160
161     bool operator++() {
162       return update(1);
163     }
164
165     bool operator--() {
166       return update(-1);
167     }
168
169    private:
170     bool update(Int delta) {
171       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
172         return false;
173       }
174
175       // This is equivalent to atomic fetch_add. We know that this operation
176       // is always performed from a single thread. asymmetricLightBarrier()
177       // makes things faster than atomic fetch_add on platforms with native
178       // support.
179       auto count = count_.load(std::memory_order_relaxed) + delta;
180       count_.store(count, std::memory_order_relaxed);
181
182       asymmetricLightBarrier();
183
184       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
185         std::lock_guard<std::mutex> lg(collectMutex_);
186
187         if (collectGuard_) {
188           return true;
189         }
190         if (collectCount_ != count) {
191           return false;
192         }
193       }
194
195       return true;
196     }
197
198     AtomicInt count_{0};
199     TLRefCount& refCount_;
200
201     std::mutex collectMutex_;
202     Int collectCount_{0};
203     std::shared_ptr<void> collectGuard_;
204   };
205
206   std::atomic<State> state_{State::LOCAL};
207   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
208   std::atomic<int64_t> globalCount_{1};
209   std::mutex globalMutex_;
210   std::shared_ptr<void> collectGuard_;
211 };
212
213 } // namespace folly