7a06d47c2bb8f80a18d88ee432feea7a483450fe
[folly.git] / folly / experimental / TLRefCount.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/Baton.h>
19 #include <folly/ThreadLocal.h>
20
21 namespace folly {
22
23 class TLRefCount {
24  public:
25   using Int = int64_t;
26
27   TLRefCount() :
28       localCount_([&]() {
29           return new LocalRefCount(*this);
30         }),
31       collectGuard_(&collectBaton_, [](void* p) {
32           auto baton = reinterpret_cast<folly::Baton<>*>(p);
33           baton->post();
34         }) {
35   }
36
37   ~TLRefCount() noexcept {
38     assert(globalCount_.load() == 0);
39     assert(state_.load() == State::GLOBAL);
40   }
41
42   // This can't increment from 0.
43   Int operator++() noexcept {
44     auto& localCount = *localCount_;
45
46     if (++localCount) {
47       return 42;
48     }
49
50     if (state_.load() == State::GLOBAL_TRANSITION) {
51       std::lock_guard<std::mutex> lg(globalMutex_);
52     }
53
54     assert(state_.load() == State::GLOBAL);
55
56     auto value = globalCount_.load();
57     do {
58       if (value == 0) {
59         return 0;
60       }
61     } while (!globalCount_.compare_exchange_weak(value, value+1));
62
63     return value + 1;
64   }
65
66   Int operator--() noexcept {
67     auto& localCount = *localCount_;
68
69     if (--localCount) {
70       return 42;
71     }
72
73     if (state_.load() == State::GLOBAL_TRANSITION) {
74       std::lock_guard<std::mutex> lg(globalMutex_);
75     }
76
77     assert(state_.load() == State::GLOBAL);
78
79     return --globalCount_;
80   }
81
82   Int operator*() const {
83     if (state_ != State::GLOBAL) {
84       return 42;
85     }
86     return globalCount_.load();
87   }
88
89   void useGlobal() noexcept {
90     std::lock_guard<std::mutex> lg(globalMutex_);
91
92     state_ = State::GLOBAL_TRANSITION;
93
94     auto accessor = localCount_.accessAllThreads();
95     for (auto& count : accessor) {
96       count.collect();
97     }
98
99     collectGuard_.reset();
100     collectBaton_.wait();
101
102     state_ = State::GLOBAL;
103   }
104
105  private:
106   using AtomicInt = std::atomic<Int>;
107
108   enum class State {
109     LOCAL,
110     GLOBAL_TRANSITION,
111     GLOBAL
112   };
113
114   class LocalRefCount {
115    public:
116     explicit LocalRefCount(TLRefCount& refCount) :
117         refCount_(refCount) {
118       std::lock_guard<std::mutex> lg(refCount.globalMutex_);
119
120       collectGuard_ = refCount.collectGuard_;
121     }
122
123     ~LocalRefCount() {
124       collect();
125     }
126
127     void collect() {
128       std::lock_guard<std::mutex> lg(collectMutex_);
129
130       if (!collectGuard_) {
131         return;
132       }
133
134       collectCount_ = count_;
135       refCount_.globalCount_ += collectCount_;
136       collectGuard_.reset();
137     }
138
139     bool operator++() {
140       return update(1);
141     }
142
143     bool operator--() {
144       return update(-1);
145     }
146
147    private:
148     bool update(Int delta) {
149       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
150         return false;
151       }
152
153       auto count = count_ += delta;
154
155       if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
156         std::lock_guard<std::mutex> lg(collectMutex_);
157
158         if (collectGuard_) {
159           return true;
160         }
161         if (collectCount_ != count) {
162           return false;
163         }
164       }
165
166       return true;
167     }
168
169     Int count_{0};
170     TLRefCount& refCount_;
171
172     std::mutex collectMutex_;
173     Int collectCount_{0};
174     std::shared_ptr<void> collectGuard_;
175   };
176
177   std::atomic<State> state_{State::LOCAL};
178   folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
179   std::atomic<int64_t> globalCount_{1};
180   std::mutex globalMutex_;
181   folly::Baton<> collectBaton_;
182   std::shared_ptr<void> collectGuard_;
183 };
184
185 }