user-defined expirations
authorJames Sedgwick <jsedgwick@fb.com>
Tue, 23 Sep 2014 18:17:03 +0000 (11:17 -0700)
committerAnton Likhtarov <alikhtarov@fb.com>
Fri, 26 Sep 2014 22:22:51 +0000 (15:22 -0700)
Summary:
Couple of notes:
1. is it a bummer not to have per-task callbacks of some kind? the interfaces set up here only tell you that some task expired, not which one expired. TM calls back with the Runnable object. is that useful?
2. std::chrono::* business is frustratingly verbose, but the safety/explicitness is nice. Not sure how I feel overall.
3. perhaps expirations should be given in microseconds even if we don't think we can accurately accomplish that

Test Plan: added unit

Reviewed By: hans@fb.com

Subscribers: fugalh, njormrod, bmatheny

FB internal diff: D1563520

folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp
folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h
folly/experimental/wangle/concurrent/IOThreadPoolExecutor.cpp
folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h
folly/experimental/wangle/concurrent/ThreadPoolExecutor.cpp
folly/experimental/wangle/concurrent/ThreadPoolExecutor.h
folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp

index ca88e5804ad6130c437b8c9cae71c887ca71b9ba..daac2eb90684bae1b40bf35f653e1cebc0db46ae 100644 (file)
@@ -36,13 +36,20 @@ CPUThreadPoolExecutor::~CPUThreadPoolExecutor() {
 }
 
 void CPUThreadPoolExecutor::add(Func func) {
+  add(std::move(func), std::chrono::milliseconds(0));
+}
+
+void CPUThreadPoolExecutor::add(
+    Func func,
+    std::chrono::milliseconds expiration,
+    Func expireCallback) {
   // TODO handle enqueue failure, here and in other add() callsites
-  taskQueue_->add(CPUTask(std::move(func)));
+  taskQueue_->add(
+      CPUTask(std::move(func), expiration, std::move(expireCallback)));
 }
 
 void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) {
   while (1) {
-    // TODO expiration / codel
     auto task = taskQueue_->take();
     if (UNLIKELY(task.poison)) {
       CHECK(threadsToStop_-- > 0);
index 7811c6783f81c2b1a660c8cdf018f6a6a4ca9618..28e2dad6e85b9b0850ed9dcf43254962d702a49c 100644 (file)
@@ -34,11 +34,22 @@ class CPUThreadPoolExecutor : public ThreadPoolExecutor {
   ~CPUThreadPoolExecutor();
 
   void add(Func func) override;
+  void add(
+      Func func,
+      std::chrono::milliseconds expiration,
+      Func expireCallback = nullptr) override;
 
   struct CPUTask : public ThreadPoolExecutor::Task {
     // Must be noexcept move constructible so it can be used in MPMCQueue
-    explicit CPUTask(Func&& f) : Task(std::move(f)), poison(false) {}
-    CPUTask() : Task(nullptr), poison(true) {}
+    explicit CPUTask(
+        Func&& f,
+        std::chrono::milliseconds expiration,
+        Func&& expireCallback)
+      : Task(std::move(f), expiration, std::move(expireCallback)),
+        poison(false) {}
+    CPUTask()
+      : Task(nullptr, std::chrono::milliseconds(0), nullptr),
+        poison(true) {}
     CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {}
     CPUTask(const CPUTask&) = default;
     CPUTask& operator=(const CPUTask&) = default;
index 6e106f92836071c0755a534984bc70f83631b627..80d5ef73164f3aa41e343ba930dc0147bb6a2f38 100644 (file)
@@ -35,6 +35,13 @@ IOThreadPoolExecutor::~IOThreadPoolExecutor() {
 }
 
 void IOThreadPoolExecutor::add(Func func) {
+  add(std::move(func), std::chrono::milliseconds(0));
+}
+
+void IOThreadPoolExecutor::add(
+    Func func,
+    std::chrono::milliseconds expiration,
+    Func expireCallback) {
   RWSpinLock::ReadHolder{&threadListLock_};
   if (threadList_.get().empty()) {
     throw std::runtime_error("No threads available");
@@ -42,7 +49,8 @@ void IOThreadPoolExecutor::add(Func func) {
   auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()];
   auto ioThread = std::static_pointer_cast<IOThread>(thread);
 
-  auto moveTask = folly::makeMoveWrapper(Task(std::move(func)));
+  auto moveTask = folly::makeMoveWrapper(
+      Task(std::move(func), expiration, std::move(expireCallback)));
   auto wrappedFunc = [this, ioThread, moveTask] () mutable {
     runTask(ioThread, std::move(*moveTask));
     ioThread->pendingTasks--;
index c42da7198f56a5761690fdc4096ba936957dfd0a..60f9d9332b5860cc716b764dc1af310532833d63 100644 (file)
@@ -30,6 +30,10 @@ class IOThreadPoolExecutor : public ThreadPoolExecutor {
   ~IOThreadPoolExecutor();
 
   void add(Func func) override;
+  void add(
+      Func func,
+      std::chrono::milliseconds expiration,
+      Func expireCallback = nullptr) override;
 
  private:
   ThreadPtr makeThread() override;
index 30e46f5c2c6326cdd14c2b036481611881cb59d4..8b0b158dd4a2b7a84f96e1ef1eff90b35cb2c0ed 100644 (file)
@@ -27,23 +27,43 @@ ThreadPoolExecutor::~ThreadPoolExecutor() {
   CHECK(threadList_.get().size() == 0);
 }
 
+ThreadPoolExecutor::Task::Task(
+    Func&& func,
+    std::chrono::milliseconds expiration,
+    Func&& expireCallback)
+    : func_(std::move(func)),
+      expiration_(expiration),
+      expireCallback_(std::move(expireCallback)) {
+  // Assume that the task in enqueued on creation
+  enqueueTime_ = std::chrono::steady_clock::now();
+}
+
 void ThreadPoolExecutor::runTask(
     const ThreadPtr& thread,
     Task&& task) {
   thread->idle = false;
-  task.started();
-  try {
-    task.func();
-  } catch (const std::exception& e) {
-    LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " <<
-                  typeid(e).name() << " exception: " << e.what();
-  } catch (...) {
-    LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
-                  "object";
+  auto startTime = std::chrono::steady_clock::now();
+  task.stats_.waitTime = startTime - task.enqueueTime_;
+  if (task.expiration_ > std::chrono::milliseconds(0) &&
+      task.stats_.waitTime >= task.expiration_) {
+    task.stats_.expired = true;
+    if (task.expireCallback_ != nullptr) {
+      task.expireCallback_();
+    }
+  } else {
+    try {
+      task.func_();
+    } catch (const std::exception& e) {
+      LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " <<
+                    typeid(e).name() << " exception: " << e.what();
+    } catch (...) {
+      LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
+                    "object";
+    }
+    task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
   }
-  task.completed();
-  taskStatsSubject_.onNext(std::move(task.stats));
   thread->idle = true;
+  taskStatsSubject_.onNext(std::move(task.stats_));
 }
 
 size_t ThreadPoolExecutor::numThreads() {
index 4eda2d36e3039fd703aeb254295c82b6629df015..bf0dfda89cbfff54f19cc67f4f7a79ec041bcd98 100644 (file)
@@ -38,6 +38,12 @@ class ThreadPoolExecutor : public experimental::Executor {
 
   ~ThreadPoolExecutor();
 
+  virtual void add(Func func) override = 0;
+  virtual void add(
+      Func func,
+      std::chrono::milliseconds expiration,
+      Func expireCallback) = 0;
+
   size_t numThreads();
   void setNumThreads(size_t numThreads);
   void stop();
@@ -55,8 +61,8 @@ class ThreadPoolExecutor : public experimental::Executor {
   struct TaskStats {
     TaskStats() : expired(false), waitTime(0), runTime(0) {}
     bool expired;
-    std::chrono::microseconds waitTime;
-    std::chrono::microseconds runTime;
+    std::chrono::nanoseconds waitTime;
+    std::chrono::nanoseconds runTime;
   };
 
   Subscription subscribeToTaskStats(
@@ -82,27 +88,15 @@ class ThreadPoolExecutor : public experimental::Executor {
   typedef std::shared_ptr<Thread> ThreadPtr;
 
   struct Task {
-    explicit Task(Func&& f) : func(std::move(f)) {
-      // Assume that the task in enqueued on creation
-      intervalBegin = std::chrono::steady_clock::now();
-    }
-
-    Func func;
-    TaskStats stats;
-    // TODO per-task timeouts, expirations
-
-    void started() {
-      auto now = std::chrono::steady_clock::now();
-      stats.waitTime = std::chrono::duration_cast<std::chrono::microseconds>(
-          now - intervalBegin);
-      intervalBegin = now;
-    }
-    void completed() {
-      stats.runTime = std::chrono::duration_cast<std::chrono::microseconds>(
-         std::chrono::steady_clock::now() - intervalBegin);
-    }
-
-    std::chrono::steady_clock::time_point intervalBegin;
+    explicit Task(
+        Func&& func,
+        std::chrono::milliseconds expiration,
+        Func&& expireCallback);
+    Func func_;
+    TaskStats stats_;
+    std::chrono::steady_clock::time_point enqueueTime_;
+    std::chrono::milliseconds expiration_;
+    Func expireCallback_;
   };
 
   void runTask(const ThreadPtr& thread, Task&& task);
index eb8527ca0ab5637758b88cb1493fac88830e3374..8b972773c07d04c1f5a07f2e23b73220d7aefd6a 100644 (file)
 #include <gtest/gtest.h>
 
 using namespace folly::wangle;
+using namespace std::chrono;
 
 static Func burnMs(uint64_t ms) {
-  return [ms]() { std::this_thread::sleep_for(std::chrono::milliseconds(ms)); };
+  return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); };
 }
 
 template <class TPE>
@@ -176,11 +177,11 @@ static void taskStats() {
       [&] (ThreadPoolExecutor::TaskStats stats) {
         int i = c++;
         if (i < 10) {
-          EXPECT_GE(10000, stats.waitTime.count());
-          EXPECT_LE(20000, stats.runTime.count());
+          EXPECT_GE(milliseconds(10), stats.waitTime);
+          EXPECT_LE(milliseconds(20), stats.runTime);
         } else {
-          EXPECT_LE(10000, stats.waitTime.count());
-          EXPECT_LE(10000, stats.runTime.count());
+          EXPECT_LE(milliseconds(10), stats.waitTime);
+          EXPECT_LE(milliseconds(10), stats.runTime);
         }
       }));
   for (int i = 0; i < 10; i++) {
@@ -200,3 +201,35 @@ TEST(ThreadPoolExecutorTest, CPUTaskStats) {
 TEST(ThreadPoolExecutorTest, IOTaskStats) {
   taskStats<IOThreadPoolExecutor>();
 }
+
+template <class TPE>
+static void expiration() {
+  TPE tpe(1);
+  std::atomic<int> statCbCount(0);
+  tpe.subscribeToTaskStats(Observer<ThreadPoolExecutor::TaskStats>::create(
+      [&] (ThreadPoolExecutor::TaskStats stats) {
+        int i = statCbCount++;
+        if (i == 0) {
+          EXPECT_FALSE(stats.expired);
+        } else if (i == 1) {
+          EXPECT_TRUE(stats.expired);
+        } else {
+          FAIL();
+        }
+      }));
+  std::atomic<int> expireCbCount(0);
+  auto expireCb = [&] () { expireCbCount++; };
+  tpe.add(burnMs(10), milliseconds(10), expireCb);
+  tpe.add(burnMs(10), milliseconds(10), expireCb);
+  tpe.join();
+  EXPECT_EQ(2, statCbCount);
+  EXPECT_EQ(1, expireCbCount);
+}
+
+TEST(ThreadPoolExecutorTest, CPUExpiration) {
+  expiration<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOExpiration) {
+  expiration<IOThreadPoolExecutor>();
+}