(FSM) updateState with unprotected action
authorHans Fugal <fugalh@fb.com>
Mon, 20 Oct 2014 19:07:36 +0000 (12:07 -0700)
committerdcsommer <dcsommer@fb.com>
Wed, 29 Oct 2014 23:02:43 +0000 (16:02 -0700)
Summary:
Like the magic macros. As it says in the comment, this can lead to nicer code.
Also added/moved a couple examples to the top of `test/FSM.cpp`.

Test Plan: new tests

Reviewed By: davejwatson@fb.com

Subscribers: trunkagent, net-systems@, fugalh, exa, njormrod

FB internal diff: D1618184

folly/wangle/detail/FSM.h
folly/wangle/test/FSM.cpp

index 96c6c280414115d87cf94c58e77a841fb78a7b54..95a0bcdef92f0757f02637d1906405026e5467b4 100644 (file)
@@ -41,16 +41,14 @@ private:
   std::atomic<Enum> state_;
 
 public:
-  FSM(Enum startState) : state_(startState) {}
+  explicit FSM(Enum startState) : state_(startState) {}
 
   Enum getState() const {
     return state_.load(std::memory_order_relaxed);
   }
 
-  // transition from state A to state B, and then perform action while the
-  // lock is still held.
-  //
-  // If the current state is not A, returns false.
+  /// Atomically do a state transition with accompanying action.
+  /// @returns true on success, false and action unexecuted otherwise
   template <class F>
   bool updateState(Enum A, Enum B, F const& action) {
     std::lock_guard<Mutex> lock(mutex_);
@@ -59,22 +57,62 @@ public:
     action();
     return true;
   }
+
+  /// Atomically do a state transition with accompanying action. Then do the
+  /// unprotected action without holding the lock. If the atomic transition
+  /// fails, returns false and neither action was executed.
+  ///
+  /// This facilitates code like this:
+  ///   bool done = false;
+  ///   while (!done) {
+  ///     switch (getState()) {
+  ///     case State::Foo:
+  ///       done = updateState(State::Foo, State::Bar,
+  ///           [&]{ /* do protected stuff */ },
+  ///           [&]{ /* do unprotected stuff */});
+  ///       break;
+  ///
+  /// Which reads nicer than code like this:
+  ///   while (true) {
+  ///     switch (getState()) {
+  ///     case State::Foo:
+  ///       if (!updateState(State::Foo, State::Bar,
+  ///           [&]{ /* do protected stuff */ })) {
+  ///         continue;
+  ///       }
+  ///       /* do unprotected stuff */
+  ///       return; // or otherwise break out of the loop
+  template <class F1, class F2>
+  bool updateState(Enum A, Enum B,
+                   F1 const& protectedAction, F2 const& unprotectedAction) {
+    bool result = updateState(A, B, protectedAction);
+    if (result) {
+      unprotectedAction();
+    }
+    return result;
+  }
 };
 
 #define FSM_START \
-  retry: \
-    switch (getState()) {
+  {bool done = false; while (!done) { auto state = getState(); switch (state) {
 
-#define FSM_UPDATE2(a, b, action, unlocked_code) \
-    case a: \
-      if (!updateState((a), (b), (action))) goto retry; \
-      { unlocked_code ; } \
-      break;
+#define FSM_UPDATE2(b, protectedAction, unprotectedAction) \
+    done = updateState(state, (b), (protectedAction), (unprotectedAction));
 
-#define FSM_UPDATE(a, b, action) FSM_UPDATE2((a), (b), (action), {})
+#define FSM_UPDATE(b, action) FSM_UPDATE2((b), (action), []{})
 
-#define FSM_END \
-    }
+#define FSM_CASE(a, b, action) \
+  case (a): \
+    FSM_UPDATE((b), (action)); \
+    break;
+
+#define FSM_CASE2(a, b, protectedAction, unprotectedAction) \
+  case (a): \
+    FSM_UPDATE2((b), (protectedAction), (unprotectedAction)); \
+    break;
+
+#define FSM_BREAK done = true; break;
+#define FSM_END }}}
 
 
 }}}
index 85fc5d0540bbc60b04a9c2530714d036d87ea67e..1ce78c2ee7779279c0fd8c957378f2431ac811ea 100644 (file)
@@ -21,44 +21,45 @@ using namespace folly::wangle::detail;
 
 enum class State { A, B };
 
-TEST(FSM, ctor) {
+TEST(FSM, example) {
   FSM<State> fsm(State::A);
-  EXPECT_EQ(State::A, fsm.getState());
-}
+  int count = 0;
+  int unprotectedCount = 0;
+
+  // somebody set up us the switch
+  auto tryTransition = [&]{
+    switch (fsm.getState()) {
+    case State::A:
+      return fsm.updateState(State::A, State::B, [&]{ count++; });
+    case State::B:
+      return fsm.updateState(State::B, State::A,
+                             [&]{ count--; }, [&]{ unprotectedCount--; });
+    }
+    return false; // unreachable
+  };
 
-TEST(FSM, update) {
-  FSM<State> fsm(State::A);
-  EXPECT_TRUE(fsm.updateState(State::A, State::B, []{}));
+  // keep retrying until success (like a cas)
+  while (!tryTransition()) ;
   EXPECT_EQ(State::B, fsm.getState());
-}
-
-TEST(FSM, badUpdate) {
-  FSM<State> fsm(State::A);
-  EXPECT_FALSE(fsm.updateState(State::B, State::A, []{}));
-}
-
-TEST(FSM, actionOnUpdate) {
-  FSM<State> fsm(State::A);
-  size_t count = 0;
-  fsm.updateState(State::A, State::B, [&]{ count++; });
   EXPECT_EQ(1, count);
-}
+  EXPECT_EQ(0, unprotectedCount);
 
-TEST(FSM, noActionOnBadUpdate) {
-  FSM<State> fsm(State::A);
-  size_t count = 0;
-  fsm.updateState(State::B, State::A, [&]{ count++; });
+  while (!tryTransition()) ;
+  EXPECT_EQ(State::A, fsm.getState());
   EXPECT_EQ(0, count);
+  EXPECT_EQ(-1, unprotectedCount);
 }
 
-TEST(FSM, magicMacros) {
+TEST(FSM, magicMacrosExample) {
   struct MyFSM : public FSM<State> {
-    size_t count = 0;
+    int count = 0;
+    int unprotectedCount = 0;
     MyFSM() : FSM<State>(State::A) {}
     void twiddle() {
       FSM_START
-        FSM_UPDATE(State::A, State::B, [&]{ count++; });
-        FSM_UPDATE(State::B, State::A, [&]{ count--; });
+        FSM_CASE(State::A, State::B, [&]{ count++; });
+        FSM_CASE2(State::B, State::A,
+                  [&]{ count--; }, [&]{ unprotectedCount--; });
       FSM_END
     }
   };
@@ -68,34 +69,47 @@ TEST(FSM, magicMacros) {
   fsm.twiddle();
   EXPECT_EQ(State::B, fsm.getState());
   EXPECT_EQ(1, fsm.count);
+  EXPECT_EQ(0, fsm.unprotectedCount);
 
   fsm.twiddle();
   EXPECT_EQ(State::A, fsm.getState());
   EXPECT_EQ(0, fsm.count);
+  EXPECT_EQ(-1, fsm.unprotectedCount);
 }
 
-TEST(FSM, magicMacros2) {
-  struct MyFSM : public FSM<State> {
-    size_t count = 0;
-    size_t count2 = 0;
-    MyFSM() : FSM<State>(State::A) {}
-    void twiddle() {
-      FSM_START
-        FSM_UPDATE2(State::A, State::B, [&]{ count++; }, count2++);
-        FSM_UPDATE2(State::B, State::A, [&]{ count--; }, count2--);
-      FSM_END
-    }
-  };
 
-  MyFSM fsm;
+TEST(FSM, ctor) {
+  FSM<State> fsm(State::A);
+  EXPECT_EQ(State::A, fsm.getState());
+}
 
-  fsm.twiddle();
+TEST(FSM, update) {
+  FSM<State> fsm(State::A);
+  EXPECT_TRUE(fsm.updateState(State::A, State::B, []{}));
   EXPECT_EQ(State::B, fsm.getState());
-  EXPECT_EQ(1, fsm.count);
-  EXPECT_EQ(1, fsm.count2);
+}
 
-  fsm.twiddle();
-  EXPECT_EQ(State::A, fsm.getState());
-  EXPECT_EQ(0, fsm.count);
-  EXPECT_EQ(0, fsm.count2);
+TEST(FSM, badUpdate) {
+  FSM<State> fsm(State::A);
+  EXPECT_FALSE(fsm.updateState(State::B, State::A, []{}));
+}
+
+TEST(FSM, actionOnUpdate) {
+  FSM<State> fsm(State::A);
+  int count = 0;
+  fsm.updateState(State::A, State::B, [&]{ count++; });
+  EXPECT_EQ(1, count);
+}
+
+TEST(FSM, noActionOnBadUpdate) {
+  FSM<State> fsm(State::A);
+  int count = 0;
+  fsm.updateState(State::B, State::A, [&]{ count++; });
+  EXPECT_EQ(0, count);
+}
+
+TEST(FSM, stateTransitionBeforeAction) {
+  FSM<State> fsm(State::A);
+  fsm.updateState(State::A, State::B,
+                  [&]{ EXPECT_EQ(State::B, fsm.getState()); });
 }