Refactor ShutdownSocketSet atomic state machine
[folly.git] / folly / io / ShutdownSocketSet.cpp
index a64ae49ae5e9ed8595dbe115d78cbea45cbed4c2..3380559d8e9da46493b7cd9925bcb5e205e8800f 100644 (file)
@@ -51,10 +51,9 @@ void ShutdownSocketSet::add(int fd) {
 
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = FREE;
 
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = FREE;
-  CHECK(sref.compare_exchange_strong(prevState,
-                                     IN_USE,
-                                     std::memory_order_acq_rel))
-    << "Invalid prev state for fd " << fd << ": " << int(prevState);
+  CHECK(sref.compare_exchange_strong(
+      prevState, IN_USE, std::memory_order_relaxed))
+      << "Invalid prev state for fd " << fd << ": " << int(prevState);
 }
 
 void ShutdownSocketSet::remove(int fd) {
 }
 
 void ShutdownSocketSet::remove(int fd) {
@@ -66,23 +65,19 @@ void ShutdownSocketSet::remove(int fd) {
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = 0;
 
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = 0;
 
-retry_load:
   prevState = sref.load(std::memory_order_relaxed);
   prevState = sref.load(std::memory_order_relaxed);
-
-retry:
-  switch (prevState) {
-  case IN_SHUTDOWN:
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-    goto retry_load;
-  case FREE:
-    LOG(FATAL) << "Invalid prev state for fd " << fd << ": " << int(prevState);
-  }
-
-  if (!sref.compare_exchange_weak(prevState,
-                                  FREE,
-                                  std::memory_order_acq_rel)) {
-    goto retry;
-  }
+  do {
+    switch (prevState) {
+      case IN_SHUTDOWN:
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+        prevState = sref.load(std::memory_order_relaxed);
+        continue;
+      case FREE:
+        LOG(FATAL) << "Invalid prev state for fd " << fd << ": "
+                   << int(prevState);
+    }
+  } while (
+      !sref.compare_exchange_weak(prevState, FREE, std::memory_order_relaxed));
 }
 
 int ShutdownSocketSet::close(int fd) {
 }
 
 int ShutdownSocketSet::close(int fd) {
@@ -95,24 +90,21 @@ int ShutdownSocketSet::close(int fd) {
   uint8_t prevState = sref.load(std::memory_order_relaxed);
   uint8_t newState = 0;
 
   uint8_t prevState = sref.load(std::memory_order_relaxed);
   uint8_t newState = 0;
 
-retry:
-  switch (prevState) {
-  case IN_USE:
-  case SHUT_DOWN:
-    newState = FREE;
-    break;
-  case IN_SHUTDOWN:
-    newState = MUST_CLOSE;
-    break;
-  default:
-    LOG(FATAL) << "Invalid prev state for fd " << fd << ": " << int(prevState);
-  }
-
-  if (!sref.compare_exchange_weak(prevState,
-                                  newState,
-                                  std::memory_order_acq_rel)) {
-    goto retry;
-  }
+  do {
+    switch (prevState) {
+      case IN_USE:
+      case SHUT_DOWN:
+        newState = FREE;
+        break;
+      case IN_SHUTDOWN:
+        newState = MUST_CLOSE;
+        break;
+      default:
+        LOG(FATAL) << "Invalid prev state for fd " << fd << ": "
+                   << int(prevState);
+    }
+  } while (!sref.compare_exchange_weak(
+      prevState, newState, std::memory_order_relaxed));
 
   return newState == FREE ? folly::closeNoInt(fd) : 0;
 }
 
   return newState == FREE ? folly::closeNoInt(fd) : 0;
 }
@@ -126,18 +118,16 @@ void ShutdownSocketSet::shutdown(int fd, bool abortive) {
 
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = IN_USE;
 
   auto& sref = data_[size_t(fd)];
   uint8_t prevState = IN_USE;
-  if (!sref.compare_exchange_strong(prevState,
-                                    IN_SHUTDOWN,
-                                    std::memory_order_acq_rel)) {
+  if (!sref.compare_exchange_strong(
+          prevState, IN_SHUTDOWN, std::memory_order_relaxed)) {
     return;
   }
 
   doShutdown(fd, abortive);
 
   prevState = IN_SHUTDOWN;
     return;
   }
 
   doShutdown(fd, abortive);
 
   prevState = IN_SHUTDOWN;
-  if (sref.compare_exchange_strong(prevState,
-                                   SHUT_DOWN,
-                                   std::memory_order_acq_rel)) {
+  if (sref.compare_exchange_strong(
+          prevState, SHUT_DOWN, std::memory_order_relaxed)) {
     return;
   }
 
     return;
   }
 
@@ -146,16 +136,15 @@ void ShutdownSocketSet::shutdown(int fd, bool abortive) {
 
   folly::closeNoInt(fd);  // ignore errors, nothing to do
 
 
   folly::closeNoInt(fd);  // ignore errors, nothing to do
 
-  CHECK(sref.compare_exchange_strong(prevState,
-                                     FREE,
-                                     std::memory_order_acq_rel))
-    << "Invalid prev state for fd " << fd << ": " << int(prevState);
+  CHECK(
+      sref.compare_exchange_strong(prevState, FREE, std::memory_order_relaxed))
+      << "Invalid prev state for fd " << fd << ": " << int(prevState);
 }
 
 void ShutdownSocketSet::shutdownAll(bool abortive) {
   for (int i = 0; i < maxFd_; ++i) {
     auto& sref = data_[size_t(i)];
 }
 
 void ShutdownSocketSet::shutdownAll(bool abortive) {
   for (int i = 0; i < maxFd_; ++i) {
     auto& sref = data_[size_t(i)];
-    if (sref.load(std::memory_order_acquire) == IN_USE) {
+    if (sref.load(std::memory_order_relaxed) == IN_USE) {
       shutdown(i, abortive);
     }
   }
       shutdown(i, abortive);
     }
   }