Copyright 2014->2015
[folly.git] / folly / gen / ParallelMap-inl.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef FOLLY_GEN_PARALLELMAP_H
18 #error This file may only be included from folly/gen/ParallelMap.h
19 #endif
20
21 #include <atomic>
22 #include <cassert>
23 #include <thread>
24 #include <type_traits>
25 #include <utility>
26 #include <vector>
27
28 #include <folly/MPMCPipeline.h>
29 #include <folly/experimental/EventCount.h>
30
31 namespace folly { namespace gen { namespace detail {
32
33 /**
34  * PMap - Map in parallel (using threads). For producing a sequence of
35  * values by passing each value from a source collection through a
36  * predicate while running the predicate in parallel in different
37  * threads.
38  *
39  * This type is usually used through the 'pmap' helper function:
40  *
41  *   auto squares = seq(1, 10) | pmap(4, fibonacci) | sum;
42  */
43 template<class Predicate>
44 class PMap : public Operator<PMap<Predicate>> {
45   Predicate pred_;
46   size_t nThreads_;
47  public:
48   PMap() {}
49
50   PMap(Predicate pred, size_t nThreads)
51     : pred_(std::move(pred)),
52       nThreads_(nThreads) { }
53
54   template<class Value,
55            class Source,
56            class Input = typename std::decay<Value>::type,
57            class Output = typename std::decay<
58              typename std::result_of<Predicate(Value)>::type
59              >::type>
60   class Generator :
61     public GenImpl<Output, Generator<Value, Source, Input, Output>> {
62     Source source_;
63     Predicate pred_;
64     const size_t nThreads_;
65
66     class ExecutionPipeline {
67       std::vector<std::thread> workers_;
68       std::atomic<bool> done_{false};
69       const Predicate& pred_;
70       MPMCPipeline<Input, Output> pipeline_;
71       EventCount wake_;
72
73      public:
74       ExecutionPipeline(const Predicate& pred, size_t nThreads)
75         : pred_(pred),
76           pipeline_(nThreads, nThreads) {
77         workers_.reserve(nThreads);
78         for (size_t i = 0; i < nThreads; i++) {
79           workers_.push_back(std::thread([this] { this->predApplier(); }));
80         }
81       }
82
83       ~ExecutionPipeline() {
84         assert(pipeline_.sizeGuess() == 0);
85         assert(done_.load());
86         for (auto& w : workers_) { w.join(); }
87       }
88
89       void stop() {
90         // prevent workers from consuming more than we produce.
91         done_.store(true, std::memory_order_release);
92         wake_.notifyAll();
93       }
94
95       bool write(Value&& value) {
96         bool wrote = pipeline_.write(std::forward<Value>(value));
97         if (wrote) {
98           wake_.notify();
99         }
100         return wrote;
101       }
102
103       void blockingWrite(Value&& value) {
104         pipeline_.blockingWrite(std::forward<Value>(value));
105         wake_.notify();
106       }
107
108       bool read(Output& out) {
109         return pipeline_.read(out);
110       }
111
112       void blockingRead(Output& out) {
113         pipeline_.blockingRead(out);
114       }
115
116      private:
117       void predApplier() {
118         // Each thread takes a value from the pipeline_, runs the
119         // predicate and enqueues the result. The pipeline preserves
120         // ordering. NOTE: don't use blockingReadStage<0> to read from
121         // the pipeline_ as there may not be any: end-of-data is signaled
122         // separately using done_/wake_.
123         Input in;
124         for (;;) {
125           auto key = wake_.prepareWait();
126
127           typename MPMCPipeline<Input, Output>::template Ticket<0> ticket;
128           if (pipeline_.template readStage<0>(ticket, in)) {
129             wake_.cancelWait();
130             Output out = pred_(std::move(in));
131             pipeline_.template blockingWriteStage<0>(ticket,
132                                                      std::move(out));
133             continue;
134           }
135
136           if (done_.load(std::memory_order_acquire)) {
137             wake_.cancelWait();
138             break;
139           }
140
141           // Not done_, but no items in the queue.
142           wake_.wait(key);
143         }
144       }
145     };
146
147   public:
148     Generator(Source source, const Predicate& pred, size_t nThreads)
149       : source_(std::move(source)),
150         pred_(pred),
151         nThreads_(nThreads ?: sysconf(_SC_NPROCESSORS_ONLN)) {
152     }
153
154     template<class Body>
155     void foreach(Body&& body) const {
156       ExecutionPipeline pipeline(pred_, nThreads_);
157
158       size_t wrote = 0;
159       size_t read = 0;
160       source_.foreach([&](Value value) {
161         if (pipeline.write(std::forward<Value>(value))) {
162           // input queue not yet full, saturate it before we process
163           // anything downstream
164           ++wrote;
165           return;
166         }
167
168         // input queue full; drain ready items from the queue
169         Output out;
170         while (pipeline.read(out)) {
171           ++read;
172           body(std::move(out));
173         }
174
175         // write the value we were going to write before we made room.
176         pipeline.blockingWrite(std::forward<Value>(value));
177         ++wrote;
178       });
179
180       pipeline.stop();
181
182       // flush the output queue
183       while (read < wrote) {
184         Output out;
185         pipeline.blockingRead(out);
186         ++read;
187         body(std::move(out));
188       }
189     }
190
191     template<class Handler>
192     bool apply(Handler&& handler) const {
193       ExecutionPipeline pipeline(pred_, nThreads_);
194
195       size_t wrote = 0;
196       size_t read = 0;
197       bool more = true;
198       source_.apply([&](Value value) {
199         if (pipeline.write(std::forward<Value>(value))) {
200           // input queue not yet full, saturate it before we process
201           // anything downstream
202           ++wrote;
203           return true;
204         }
205
206         // input queue full; drain ready items from the queue
207         Output out;
208         while (pipeline.read(out)) {
209           ++read;
210           if (!handler(std::move(out))) {
211             more = false;
212             return false;
213           }
214         }
215
216         // write the value we were going to write before we made room.
217         pipeline.blockingWrite(std::forward<Value>(value));
218         ++wrote;
219         return true;
220       });
221
222       pipeline.stop();
223
224       // flush the output queue
225       while (read < wrote) {
226         Output out;
227         pipeline.blockingRead(out);
228         ++read;
229         if (more) {
230           more = more && handler(std::move(out));
231         }
232       }
233       return more;
234     }
235
236     static constexpr bool infinite = Source::infinite;
237   };
238
239   template<class Source,
240            class Value,
241            class Gen = Generator<Value, Source>>
242   Gen compose(GenImpl<Value, Source>&& source) const {
243     return Gen(std::move(source.self()), pred_, nThreads_);
244   }
245
246   template<class Source,
247            class Value,
248            class Gen = Generator<Value, Source>>
249   Gen compose(const GenImpl<Value, Source>& source) const {
250     return Gen(source.self(), pred_, nThreads_);
251   }
252 };
253
254 }}}  // namespaces