66b30042e5a310c81b98b0709bfaf876b748b658
[folly.git] / folly / fibers / WhenN-inl.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/Optional.h>
17
18 #include <folly/fibers/FiberManagerInternal.h>
19 #include <folly/fibers/ForEach.h>
20
21 namespace folly {
22 namespace fibers {
23
24 template <class InputIterator>
25 typename std::vector<typename std::enable_if<
26     !std::is_same<
27         typename std::result_of<
28             typename std::iterator_traits<InputIterator>::value_type()>::type,
29         void>::value,
30     typename std::pair<
31         size_t,
32         typename std::result_of<typename std::iterator_traits<
33             InputIterator>::value_type()>::type>>::type>
34 collectN(InputIterator first, InputIterator last, size_t n) {
35   typedef typename std::result_of<
36       typename std::iterator_traits<InputIterator>::value_type()>::type Result;
37   assert(n > 0);
38   assert(std::distance(first, last) >= 0);
39   assert(n <= static_cast<size_t>(std::distance(first, last)));
40
41   struct Context {
42     std::vector<std::pair<size_t, Result>> results;
43     size_t tasksTodo;
44     std::exception_ptr e;
45     folly::Optional<Promise<void>> promise;
46
47     Context(size_t tasksTodo_) : tasksTodo(tasksTodo_) {
48       this->results.reserve(tasksTodo_);
49     }
50   };
51   auto context = std::make_shared<Context>(n);
52
53   await([first, last, context](Promise<void> promise) mutable {
54     context->promise = std::move(promise);
55     for (size_t i = 0; first != last; ++i, ++first) {
56       addTask([ i, context, f = std::move(*first) ]() {
57         try {
58           auto result = f();
59           if (context->tasksTodo == 0) {
60             return;
61           }
62           context->results.emplace_back(i, std::move(result));
63         } catch (...) {
64           if (context->tasksTodo == 0) {
65             return;
66           }
67           context->e = std::current_exception();
68         }
69         if (--context->tasksTodo == 0) {
70           context->promise->setValue();
71         }
72       });
73     }
74   });
75
76   if (context->e != std::exception_ptr()) {
77     std::rethrow_exception(context->e);
78   }
79
80   return std::move(context->results);
81 }
82
83 template <class InputIterator>
84 typename std::enable_if<
85     std::is_same<
86         typename std::result_of<
87             typename std::iterator_traits<InputIterator>::value_type()>::type,
88         void>::value,
89     std::vector<size_t>>::type
90 collectN(InputIterator first, InputIterator last, size_t n) {
91   assert(n > 0);
92   assert(std::distance(first, last) >= 0);
93   assert(n <= static_cast<size_t>(std::distance(first, last)));
94
95   struct Context {
96     std::vector<size_t> taskIndices;
97     std::exception_ptr e;
98     size_t tasksTodo;
99     folly::Optional<Promise<void>> promise;
100
101     Context(size_t tasksTodo_) : tasksTodo(tasksTodo_) {
102       this->taskIndices.reserve(tasksTodo_);
103     }
104   };
105   auto context = std::make_shared<Context>(n);
106
107   await([first, last, context](Promise<void> promise) mutable {
108     context->promise = std::move(promise);
109     for (size_t i = 0; first != last; ++i, ++first) {
110       addTask([ i, context, f = std::move(*first) ]() {
111         try {
112           f();
113           if (context->tasksTodo == 0) {
114             return;
115           }
116           context->taskIndices.push_back(i);
117         } catch (...) {
118           if (context->tasksTodo == 0) {
119             return;
120           }
121           context->e = std::current_exception();
122         }
123         if (--context->tasksTodo == 0) {
124           context->promise->setValue();
125         }
126       });
127     }
128   });
129
130   if (context->e != std::exception_ptr()) {
131     std::rethrow_exception(context->e);
132   }
133
134   return context->taskIndices;
135 }
136
137 template <class InputIterator>
138 typename std::vector<
139     typename std::enable_if<
140         !std::is_same<
141             typename std::result_of<typename std::iterator_traits<
142                 InputIterator>::value_type()>::type,
143             void>::value,
144         typename std::result_of<
145             typename std::iterator_traits<InputIterator>::value_type()>::type>::
146         type> inline collectAll(InputIterator first, InputIterator last) {
147   typedef typename std::result_of<
148       typename std::iterator_traits<InputIterator>::value_type()>::type Result;
149   size_t n = size_t(std::distance(first, last));
150   std::vector<Result> results;
151   std::vector<size_t> order(n);
152   results.reserve(n);
153
154   forEach(first, last, [&results, &order](size_t id, Result result) {
155     order[id] = results.size();
156     results.emplace_back(std::move(result));
157   });
158   assert(results.size() == n);
159
160   std::vector<Result> orderedResults;
161   orderedResults.reserve(n);
162
163   for (size_t i = 0; i < n; ++i) {
164     orderedResults.emplace_back(std::move(results[order[i]]));
165   }
166
167   return orderedResults;
168 }
169
170 template <class InputIterator>
171 typename std::enable_if<
172     std::is_same<
173         typename std::result_of<
174             typename std::iterator_traits<InputIterator>::value_type()>::type,
175         void>::value,
176     void>::type inline collectAll(InputIterator first, InputIterator last) {
177   forEach(first, last, [](size_t /* id */) {});
178 }
179
180 template <class InputIterator>
181 typename std::enable_if<
182     !std::is_same<
183         typename std::result_of<
184             typename std::iterator_traits<InputIterator>::value_type()>::type,
185         void>::value,
186     typename std::pair<
187         size_t,
188         typename std::result_of<typename std::iterator_traits<
189             InputIterator>::value_type()>::type>>::
190     type inline collectAny(InputIterator first, InputIterator last) {
191   auto result = collectN(first, last, 1);
192   assert(result.size() == 1);
193   return std::move(result[0]);
194 }
195
196 template <class InputIterator>
197 typename std::enable_if<
198     std::is_same<
199         typename std::result_of<
200             typename std::iterator_traits<InputIterator>::value_type()>::type,
201         void>::value,
202     size_t>::type inline collectAny(InputIterator first, InputIterator last) {
203   auto result = collectN(first, last, 1);
204   assert(result.size() == 1);
205   return std::move(result[0]);
206 }
207 }
208 }