9667f4c100cb5b1d62c188a6b9c05585547c2c49
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/ThreadLocal.h>
19
20 namespace folly {
21
22 class TLRefCount {
23  public:
24   using Int = int64_t;
25
26   TLRefCount()
27       : localCount_([&]() { return new LocalRefCount(*this); }),
28         collectGuard_(this, [](void*) {}) {}
29
30   ~TLRefCount() noexcept {
31     assert(globalCount_.load() == 0);
32     assert(state_.load() == State::GLOBAL);
33   }
34
35   // This can't increment from 0.
36   Int operator++() noexcept {
37     auto& localCount = *localCount_;
38
39     if (++localCount) {
40       return 42;
41     }
42
43     if (state_.load() == State::GLOBAL_TRANSITION) {
44       std::lock_guard<std::mutex> lg(globalMutex_);
45     }
46
47     assert(state_.load() == State::GLOBAL);
48
49     auto value = globalCount_.load();
50     do {
51       if (value == 0) {
52         return 0;
53       }
54     } while (!globalCount_.compare_exchange_weak(value, value+1));
55
56     return value + 1;
57   }
58
59   Int operator--() noexcept {
60     auto& localCount = *localCount_;
61
62     if (--localCount) {
63       return 42;
64     }
65
66     if (state_.load() == State::GLOBAL_TRANSITION) {
67       std::lock_guard<std::mutex> lg(globalMutex_);
68     }
69
70     assert(state_.load() == State::GLOBAL);
71
72     return globalCount_-- - 1;
73   }
74
75   Int operator*() const {
76     if (state_ != State::GLOBAL) {
77       return 42;
78     }
79     return globalCount_.load();
80   }
81
82   void useGlobal() noexcept {
83     std::lock_guard<std::mutex> lg(globalMutex_);
84
85     state_ = State::GLOBAL_TRANSITION;
86
87     std::weak_ptr<void> collectGuardWeak = collectGuard_;
88
89     // Make sure we can't create new LocalRefCounts
90     collectGuard_.reset();
91
92     while (!collectGuardWeak.expired()) {
93       auto accessor = localCount_.accessAllThreads();
94       for (auto& count : accessor) {
95         count.collect();
96       }
97     }
98
99     state_ = State::GLOBAL;
100   }
101
102  private:
103   using AtomicInt = std::atomic<Int>;
104
105   enum class State {
106     LOCAL,
107     GLOBAL_TRANSITION,
108     GLOBAL
109   };
110
111   class LocalRefCount {
112    public:
113     explicit LocalRefCount(TLRefCount& refCount) :
114         refCount_(refCount) {
115       std::lock_guard<std::mutex> lg(refCount.globalMutex_);
116
117       collectGuard_ = refCount.collectGuard_;
118     }
119
120     ~LocalRefCount() {
121       collect();
122     }
123
124     void collect() {
125       std::lock_guard<std::mutex> lg(collectMutex_);
126
127       if (!collectGuard_) {
128         return;
129       }
130
131       collectCount_ = count_.load();
132       refCount_.globalCount_.fetch_add(collectCount_);
133       collectGuard_.reset();
134     }
135
136     bool operator++() {
137       return update(1);
138     }
139
140     bool operator--() {
141       return update(-1);
142     }
143
144    private:
145     bool update(Int delta) {
146       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
147         return false;
148       }
149
150       auto count = count_ += delta;
151
152       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
153         std::lock_guard<std::mutex> lg(collectMutex_);
154
155         if (collectGuard_) {
156           return true;
157         }
158         if (collectCount_ != count) {
159           return false;
160         }
161       }
162
163       return true;
164     }
165
166     AtomicInt count_{0};
167     TLRefCount& refCount_;
168
169     std::mutex collectMutex_;
170     Int collectCount_{0};
171     std::shared_ptr<void> collectGuard_;
172   };
173
174   std::atomic<State> state_{State::LOCAL};
175   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
176   std::atomic<int64_t> globalCount_{1};
177   std::mutex globalMutex_;
178   std::shared_ptr<void> collectGuard_;
179 };
180
181 }