Fix RefCountTest and RCURefCount race
authorAndrii Grynenko <andrii@fb.com>
Fri, 11 Dec 2015 02:16:51 +0000 (18:16 -0800)
committerfacebook-github-bot-0 <folly-bot@fb.com>
Fri, 11 Dec 2015 03:20:23 +0000 (19:20 -0800)
Reviewed By: alikhtarov

Differential Revision: D2741459

fb-gh-sync-id: c4bd068cf735ae25364edba40960096fb35e8c43

folly/experimental/RCURefCount.h
folly/experimental/test/RefCountTest.cpp

index 567006eadf431515c35c58a26e7a2a67a8538268..3814f8199e558e91058ef895b11b01cc34d1de33 100644 (file)
@@ -40,12 +40,13 @@ class RCURefCount {
     auto& localCount = *localCount_;
 
     std::lock_guard<RCUReadLock> lg(RCUReadLock::instance());
     auto& localCount = *localCount_;
 
     std::lock_guard<RCUReadLock> lg(RCUReadLock::instance());
+    auto state = state_.load();
 
 
-    if (LIKELY(state_ == State::LOCAL)) {
+    if (LIKELY(state == State::LOCAL)) {
       ++localCount;
 
       return 42;
       ++localCount;
 
       return 42;
-    } else if (state_ == State::GLOBAL_TRANSITION) {
+    } else if (state == State::GLOBAL_TRANSITION) {
       ++globalCount_;
 
       return 42;
       ++globalCount_;
 
       return 42;
@@ -67,15 +68,16 @@ class RCURefCount {
     auto& localCount = *localCount_;
 
     std::lock_guard<RCUReadLock> lg(RCUReadLock::instance());
     auto& localCount = *localCount_;
 
     std::lock_guard<RCUReadLock> lg(RCUReadLock::instance());
+    auto state = state_.load();
 
 
-    if (LIKELY(state_ == State::LOCAL)) {
+    if (LIKELY(state == State::LOCAL)) {
       --localCount;
 
       return 42;
     } else {
       auto value = --globalCount_;
 
       --localCount;
 
       return 42;
     } else {
       auto value = --globalCount_;
 
-      if (state_ == State::GLOBAL) {
+      if (state == State::GLOBAL) {
         assert(value >= 0);
         return value;
       } else {
         assert(value >= 0);
         return value;
       } else {
index 48ad80d91e396d3a4667de45670645ecf499b110..6cd72474eb13413c5ad9437cce983bbfa8333eca 100644 (file)
@@ -69,7 +69,9 @@ void basicTest() {
   b.wait();
 
   count.useGlobal();
   b.wait();
 
   count.useGlobal();
-  --count;
+  if (--count == 0) {
+    ++got0;
+  }
 
   for (auto& t: ts) {
     t.join();
 
   for (auto& t: ts) {
     t.join();