Polymorphic Functor implementation in Folly::FutureDAG
[folly.git] / folly / experimental / FutureDAG.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/futures/Future.h>
19 #include <folly/futures/SharedPromise.h>
20
21 namespace folly {
22
23 class FutureDAG : public std::enable_shared_from_this<FutureDAG> {
24  public:
25   static std::shared_ptr<FutureDAG> create() {
26     return std::shared_ptr<FutureDAG>(new FutureDAG());
27   }
28
29   typedef size_t Handle;
30   typedef std::function<Future<Unit>()> FutureFunc;
31
32   Handle add(FutureFunc func, Executor* executor = nullptr) {
33     nodes.emplace_back(std::move(func), executor);
34     return nodes.size() - 1;
35   }
36
37   void remove(Handle a) {
38     if (nodes.size() > a && nodes[a].hasDependents) {
39       for (auto& node : nodes) {
40         auto& deps = node.dependencies;
41         deps.erase(
42             std::remove(std::begin(deps), std::end(deps), a), std::end(deps));
43         for (Handle& handle : deps) {
44           if (handle > a) {
45             handle--;
46           }
47         }
48       }
49     }
50     nodes.erase(nodes.begin() + a);
51   }
52
53   void reset() {
54     // Delete all but source node, and reset dependency properties
55     Handle source_node;
56     std::unordered_set<Handle> memo;
57     for (auto& node : nodes) {
58       for (Handle handle : node.dependencies) {
59         memo.insert(handle);
60       }
61     }
62     for (Handle handle = 0; handle < nodes.size(); handle++) {
63       if (memo.find(handle) == memo.end()) {
64         source_node = handle;
65       }
66     }
67
68     nodes.erase(nodes.begin(), nodes.begin() + source_node);
69     nodes.erase(nodes.begin() + 1, nodes.end());
70     nodes[0].hasDependents = false;
71     nodes[0].dependencies.clear();
72   }
73
74   void dependency(Handle a, Handle b) {
75     nodes[b].dependencies.push_back(a);
76     nodes[a].hasDependents = true;
77   }
78
79   void clean_state(Handle source, Handle sink) {
80     for (auto handle : nodes[sink].dependencies) {
81       nodes[handle].hasDependents = false;
82     }
83     nodes[0].hasDependents = false;
84     remove(source);
85     remove(sink);
86   }
87
88   Future<Unit> go() {
89     if (hasCycle()) {
90       return makeFuture<Unit>(std::runtime_error("Cycle in FutureDAG graph"));
91     }
92     std::vector<Handle> rootNodes;
93     std::vector<Handle> leafNodes;
94     for (Handle handle = 0; handle < nodes.size(); handle++) {
95       if (nodes[handle].dependencies.empty()) {
96         rootNodes.push_back(handle);
97       }
98       if (!nodes[handle].hasDependents) {
99         leafNodes.push_back(handle);
100       }
101     }
102
103     auto sinkHandle = add([] { return Future<Unit>(); });
104     for (auto handle : leafNodes) {
105       dependency(handle, sinkHandle);
106     }
107
108     auto sourceHandle = add(nullptr);
109     for (auto handle : rootNodes) {
110       dependency(sourceHandle, handle);
111     }
112
113     for (Handle handle = 0; handle < nodes.size() - 1; handle++) {
114       std::vector<Future<Unit>> dependencies;
115       for (auto depHandle : nodes[handle].dependencies) {
116         dependencies.push_back(nodes[depHandle].promise.getFuture());
117       }
118
119       collect(dependencies)
120           .via(nodes[handle].executor)
121           .then([this, handle] {
122             nodes[handle].func().then([this, handle](Try<Unit>&& t) {
123               nodes[handle].promise.setTry(std::move(t));
124             });
125           })
126           .onError([this, handle](exception_wrapper ew) {
127             nodes[handle].promise.setException(std::move(ew));
128           });
129     }
130
131     nodes[sourceHandle].promise.setValue();
132     auto that = shared_from_this();
133     return nodes[sinkHandle].promise.getFuture().ensure([that] {}).then(
134         [this, sourceHandle, sinkHandle]() {
135           clean_state(sourceHandle, sinkHandle);
136         });
137   }
138
139  private:
140   FutureDAG() = default;
141
142   bool hasCycle() {
143     // Perform a modified topological sort to detect cycles
144     std::vector<std::vector<Handle>> dependencies;
145     for (auto& node : nodes) {
146       dependencies.push_back(node.dependencies);
147     }
148
149     std::vector<size_t> dependents(nodes.size());
150     for (auto& dependencyEdges : dependencies) {
151       for (auto handle : dependencyEdges) {
152         dependents[handle]++;
153       }
154     }
155
156     std::vector<Handle> handles;
157     for (Handle handle = 0; handle < nodes.size(); handle++) {
158       if (!nodes[handle].hasDependents) {
159         handles.push_back(handle);
160       }
161     }
162
163     while (!handles.empty()) {
164       auto handle = handles.back();
165       handles.pop_back();
166       while (!dependencies[handle].empty()) {
167         auto dependency = dependencies[handle].back();
168         dependencies[handle].pop_back();
169         if (--dependents[dependency] == 0) {
170           handles.push_back(dependency);
171         }
172       }
173     }
174
175     for (auto& dependencyEdges : dependencies) {
176       if (!dependencyEdges.empty()) {
177         return true;
178       }
179     }
180
181     return false;
182   }
183
184   struct Node {
185     Node(FutureFunc&& funcArg, Executor* executorArg)
186         : func(std::move(funcArg)), executor(executorArg) {}
187
188     FutureFunc func{nullptr};
189     Executor* executor{nullptr};
190     SharedPromise<Unit> promise;
191     std::vector<Handle> dependencies;
192     bool hasDependents{false};
193     bool visited{false};
194   };
195
196   std::vector<Node> nodes;
197 };
198
199 // Polymorphic functor implementation
200 template <typename T>
201 class FutureDAGFunctor {
202  public:
203   std::shared_ptr<FutureDAG> dag = FutureDAG::create();
204   T state;
205   std::vector<T> dep_states;
206   T result() {
207     return state;
208   };
209   // execReset() runs DAG & clears all nodes except for source
210   void execReset() {
211     this->dag->go().get();
212     this->dag->reset();
213   };
214   void exec() {
215     this->dag->go().get();
216   };
217   virtual void operator()(){};
218   explicit FutureDAGFunctor(T init_val) : state(init_val) {}
219   FutureDAGFunctor() : state() {}
220   virtual ~FutureDAGFunctor(){};
221 };
222
223 } // folly