fix folly::FunctionScheduler.cancelFunctionAndWait() hanging issue
[folly.git] / folly / experimental / FunctionScheduler.cpp
index d648ce2645181fc1816e25d588f5eaf4f52f9777..2a94501523340d28c82ddc07afbbdee20236cc1e 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -63,7 +63,7 @@ struct PoissonDistributionFunctor {
 
 struct UniformDistributionFunctor {
   std::default_random_engine generator;
-  std::uniform_int_distribution<> dist;
+  std::uniform_int_distribution<milliseconds::rep> dist;
 
   UniformDistributionFunctor(milliseconds minInterval, milliseconds maxInterval)
       : generator(Random::rand32()),
@@ -92,58 +92,88 @@ FunctionScheduler::~FunctionScheduler() {
   shutdown();
 }
 
-void FunctionScheduler::addFunction(const std::function<void()>& cb,
+void FunctionScheduler::addFunction(Function<void()>&& cb,
                                     milliseconds interval,
                                     StringPiece nameID,
                                     milliseconds startDelay) {
-  addFunctionGenericDistribution(
-      cb,
-      IntervalDistributionFunc(ConstIntervalFunctor(interval)),
+  addFunctionInternal(
+      std::move(cb),
+      ConstIntervalFunctor(interval),
       nameID.str(),
       to<std::string>(interval.count(), "ms"),
-      startDelay);
+      startDelay,
+      false /*runOnce*/);
 }
 
-void FunctionScheduler::addFunction(const std::function<void()>& cb,
+void FunctionScheduler::addFunction(Function<void()>&& cb,
                                     milliseconds interval,
                                     const LatencyDistribution& latencyDistr,
                                     StringPiece nameID,
                                     milliseconds startDelay) {
   if (latencyDistr.isPoisson) {
-    addFunctionGenericDistribution(
-        cb,
-        IntervalDistributionFunc(
-            PoissonDistributionFunctor(latencyDistr.poissonMean)),
+    addFunctionInternal(
+        std::move(cb),
+        PoissonDistributionFunctor(latencyDistr.poissonMean),
         nameID.str(),
         to<std::string>(latencyDistr.poissonMean, "ms (Poisson mean)"),
-        startDelay);
+        startDelay,
+        false /*runOnce*/);
   } else {
-    addFunction(cb, interval, nameID, startDelay);
+    addFunction(std::move(cb), interval, nameID, startDelay);
   }
 }
 
+void FunctionScheduler::addFunctionOnce(
+    Function<void()>&& cb,
+    StringPiece nameID,
+    milliseconds startDelay) {
+  addFunctionInternal(
+      std::move(cb),
+      ConstIntervalFunctor(milliseconds::zero()),
+      nameID.str(),
+      "once",
+      startDelay,
+      true /*runOnce*/);
+}
+
 void FunctionScheduler::addFunctionUniformDistribution(
-    const std::function<void()>& cb,
+    Function<void()>&& cb,
     milliseconds minInterval,
     milliseconds maxInterval,
     StringPiece nameID,
     milliseconds startDelay) {
-  addFunctionGenericDistribution(
-      cb,
-      IntervalDistributionFunc(
-          UniformDistributionFunctor(minInterval, maxInterval)),
+  addFunctionInternal(
+      std::move(cb),
+      UniformDistributionFunctor(minInterval, maxInterval),
       nameID.str(),
       to<std::string>(
           "[", minInterval.count(), " , ", maxInterval.count(), "] ms"),
-      startDelay);
+      startDelay,
+      false /*runOnce*/);
 }
 
 void FunctionScheduler::addFunctionGenericDistribution(
-    const std::function<void()>& cb,
-    const IntervalDistributionFunc& intervalFunc,
+    Function<void()>&& cb,
+    IntervalDistributionFunc&& intervalFunc,
     const std::string& nameID,
     const std::string& intervalDescr,
     milliseconds startDelay) {
+  addFunctionInternal(
+      std::move(cb),
+      std::move(intervalFunc),
+      nameID,
+      intervalDescr,
+      startDelay,
+      false /*runOnce*/);
+}
+
+void FunctionScheduler::addFunctionInternal(
+    Function<void()>&& cb,
+    IntervalDistributionFunc&& intervalFunc,
+    const std::string& nameID,
+    const std::string& intervalDescr,
+    milliseconds startDelay,
+    bool runOnce) {
   if (!cb) {
     throw std::invalid_argument(
         "FunctionScheduler: Scheduled function must be set");
@@ -173,16 +203,51 @@ void FunctionScheduler::addFunctionGenericDistribution(
   }
 
   addFunctionToHeap(
-      l, RepeatFunc(cb, intervalFunc, nameID, intervalDescr, startDelay));
+      l,
+      RepeatFunc(
+          std::move(cb),
+          std::move(intervalFunc),
+          nameID,
+          intervalDescr,
+          startDelay,
+          runOnce));
 }
 
-bool FunctionScheduler::cancelFunction(StringPiece nameID) {
-  std::unique_lock<std::mutex> l(mutex_);
-
+bool FunctionScheduler::cancelFunctionWithLock(
+    std::unique_lock<std::mutex>& lock,
+    StringPiece nameID) {
+  CHECK_EQ(lock.owns_lock(), true);
   if (currentFunction_ && currentFunction_->name == nameID) {
     // This function is currently being run. Clear currentFunction_
     // The running thread will see this and won't reschedule the function.
     currentFunction_ = nullptr;
+    cancellingCurrentFunction_ = true;
+    return true;
+  }
+  return false;
+}
+
+bool FunctionScheduler::cancelFunction(StringPiece nameID) {
+  std::unique_lock<std::mutex> l(mutex_);
+
+  if (cancelFunctionWithLock(l, nameID)) {
+    return true;
+  }
+
+  for (auto it = functions_.begin(); it != functions_.end(); ++it) {
+    if (it->isValid() && it->name == nameID) {
+      cancelFunction(l, it);
+      return true;
+    }
+  }
+  return false;
+}
+
+bool FunctionScheduler::cancelFunctionAndWait(StringPiece nameID) {
+  std::unique_lock<std::mutex> l(mutex_);
+
+  if (cancelFunctionWithLock(l, nameID)) {
+    runningCondvar_.wait(l, [this]() { return !cancellingCurrentFunction_; });
     return true;
   }
 
@@ -216,9 +281,27 @@ void FunctionScheduler::cancelFunction(const std::unique_lock<std::mutex>& l,
   }
 }
 
+bool FunctionScheduler::cancelAllFunctionsWithLock(
+    std::unique_lock<std::mutex>& lock) {
+  CHECK_EQ(lock.owns_lock(), true);
+  functions_.clear();
+  if (currentFunction_) {
+    cancellingCurrentFunction_ = true;
+  }
+  currentFunction_ = nullptr;
+  return cancellingCurrentFunction_;
+}
+
 void FunctionScheduler::cancelAllFunctions() {
   std::unique_lock<std::mutex> l(mutex_);
-  functions_.clear();
+  cancelAllFunctionsWithLock(l);
+}
+
+void FunctionScheduler::cancelAllFunctionsAndWait() {
+  std::unique_lock<std::mutex> l(mutex_);
+  if (cancelAllFunctionsWithLock(l)) {
+    runningCondvar_.wait(l, [this]() { return !cancellingCurrentFunction_; });
+  }
 }
 
 bool FunctionScheduler::resetFunctionTimer(StringPiece nameID) {
@@ -253,8 +336,6 @@ bool FunctionScheduler::start() {
     return false;
   }
 
-  running_ = true;
-
   VLOG(1) << "Starting FunctionScheduler with " << functions_.size()
           << " functions.";
   auto now = steady_clock::now();
@@ -269,20 +350,23 @@ bool FunctionScheduler::start() {
   std::make_heap(functions_.begin(), functions_.end(), fnCmp_);
 
   thread_ = std::thread([&] { this->run(); });
+  running_ = true;
+
   return true;
 }
 
-void FunctionScheduler::shutdown() {
+bool FunctionScheduler::shutdown() {
   {
     std::lock_guard<std::mutex> g(mutex_);
     if (!running_) {
-      return;
+      return false;
     }
 
     running_ = false;
     runningCondvar_.notify_one();
   }
   thread_.join();
+  return true;
 }
 
 void FunctionScheduler::run() {
@@ -316,6 +400,7 @@ void FunctionScheduler::run() {
     if (sleepTime < milliseconds::zero()) {
       // We need to run this function now
       runOneFunction(lock, now);
+      runningCondvar_.notify_all();
     } else {
       // Re-add the function to the heap, and wait until we actually
       // need to run it.
@@ -374,6 +459,12 @@ void FunctionScheduler::runOneFunction(std::unique_lock<std::mutex>& lock,
   if (!currentFunction_) {
     // The function was cancelled while we were running it.
     // We shouldn't reschedule it;
+    cancellingCurrentFunction_ = false;
+    return;
+  }
+  if (currentFunction_->runOnce) {
+    // Don't reschedule if the function only needed to run once.
+    currentFunction_ = nullptr;
     return;
   }
   // Clear currentFunction_