libcoro  1.0
Coroutine support library for C++20
thread_pool.h
#pragma once
#include "function.h"
#include "prepared_coro.h"
#include "exceptions.h"
#include "future.h"
#include <algorithm>
#include <compare>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
11 
12 namespace coro {
13 
15 
26 template<typename CondVar>
28 public:
29 
30 
32 
37  thread_pool_t(unsigned int threads):_to_start(threads) {
38  _thr.reserve(threads);
39  }
40 
41  template<std::convertible_to<CondVar> CondVarInit>
42  thread_pool_t(unsigned int threads, CondVarInit &&cinit)
43  :_to_start(threads),_cond(std::forward<CondVarInit>(cinit)) {
44 
45  }
46 
48 
53  template<std::invocable<> Fn>
54  bool enqueue(Fn &&fn) {
55  std::unique_lock lk(_mx);
56  if (_stop) return false;
57  _que.emplace(std::forward<Fn>(fn));
58  ++_enqueued;
59  notify(lk);
60  return true;
61  }
62 
63 
65  template<std::invocable<> Fn>
66  bool operator>>(Fn &&fn) {
67  return enqueue(std::forward<Fn>(fn));
68  }
69 
70 
72 
79  template<typename Fn, typename ... Args>
80  requires std::invocable<Fn, Args...>
81  auto run(Fn &&fn, Args && ... args) -> future<std::invoke_result_t<Fn, Args...> > {
82  return [&](auto promise) {
83  enqueue([promise = std::move(promise),
84  fn = std::move(fn),
85  args = std::make_tuple(std::forward<Args>(args)...)]() mutable{
86  try {
87  if constexpr(std::is_void_v<std::invoke_result_t<Fn, Args...> >) {
88  std::apply(fn, std::move(args));
89  promise();
90  } else {
91  promise([&]{return std::apply(fn, std::move(args));});
92  }
93  } catch (...) {
94  promise.reject();
95  }
96  });
97  };
98 
99  }
100 
102 
105  void stop() {
106  std::vector<std::thread> tmp;
107  {
108  std::lock_guard lk(_mx);
109  if (_stop) return;
110  _stop = true;
111  _to_start = 0;
112  std::swap(_thr, tmp);
113  }
114  _cond.notify_all();
115  auto this_thr = std::this_thread::get_id();
116  for (auto &t: tmp) {
117  if (this_thr == t.get_id()) {
118  _current = nullptr;
119  t.detach();
120  } else {
121  t.join();
122  }
123  }
124  _que = {};
125 
126  }
127 
130  stop();
131  }
132 
/// co_await support: the pool is never "ready", so awaiting it always
/// suspends the coroutine and reschedules it via await_suspend().
static constexpr bool await_ready() noexcept {
    return false;
}
136 
139  void await_resume() {
140  if (is_stopped()) throw await_canceled_exception();
141  }
143 
/// co_await support: schedules the suspended coroutine @p h to be resumed
/// on one of the pool's threads.
/// @throws await_canceled_exception if the pool is already stopped
148  void await_suspend(std::coroutine_handle<> h) {
149  prepared_coro prep(h);
// enqueue() checks _stop before it moves from its argument, so when it
// returns false prep still owns h here.
150  if (!enqueue(std::move(prep))) {
// release the handle so prep's destructor does not act on h (presumably it
// would otherwise resume/destroy it — confirm in prepared_coro.h); the
// exception then propagates out of the co_await expression.
151  prep.release();
152  throw await_canceled_exception();
153  }
154  }
155 
157 
163  return [&](auto promise) {
164  std::lock_guard lk(_mx);
165  if (_enqueued == _finished) {
166  promise();
167  } else {
168  _joins.push_back({_enqueued, std::move(promise)});
169  std::push_heap(_joins.begin(), _joins.end());
170  }
171  };
172  }
173 
175  bool is_stopped() const {
176  std::lock_guard lk(_mx);
177  return _stop;
178  }
180  static thread_pool_t *current() {return _current;}
181 
182 protected:
183 
184 
/// worker threads started so far (handed over to stop() for joining)
185  std::vector<std::thread> _thr;
/// serializes access to the pool state; mutable so is_stopped() const can lock it
186  mutable std::mutex _mx;
/// idle workers wait on this; notified by notify() and stop()
187  CondVar _cond;
/// pending tasks, executed in FIFO order
188  std::queue<function<void()> > _que;
/// number of tasks that have completed (incremented in check_join())
189  long _finished = 0;
/// total number of tasks accepted by enqueue()
190  long _enqueued = 0;
/// remaining thread budget; notify() starts one new thread per enqueue until it reaches 0
191  unsigned int _to_start = 0;
/// set by stop(); once true, enqueue() rejects tasks and workers exit their loop
192  bool _stop = false;
193 
194  struct join_info {
195  long _target;
196  promise<void> _prom;
197  int operator<=>(const join_info &other) const {
198  return other._target - _target;
199  }
200  };
201 
/// pending join() requests, kept as a heap (std::push_heap/pop_heap) whose
/// front() is the request with the smallest _target
202  std::vector<join_info> _joins;
203 
/// per-thread pointer to the pool owning the calling worker thread;
/// nullptr on other threads (see current())
204  static thread_local thread_pool_t *_current;
205 
/// Records completion of one task and resolves every pending join() whose
/// target has been reached.
/// @param lk lock on _mx; must be UNLOCKED on entry, is LOCKED on return
///           (on every path, including the early return)
206  void check_join(std::unique_lock<std::mutex> &lk) {
// collects the promises that became due so they can be resolved together
207  promise<void> joined;
208  lk.lock();
209  ++_finished;
// fast path: nothing is due yet — return with the lock still held
210  if (_joins.empty() || _joins.front()._target > _finished) return;
211  do {
// NOTE(review): operator+= presumably chains the promise into `joined` — see future.h
212  joined += _joins.front()._prom;
213  std::pop_heap(_joins.begin(), _joins.end());
214  _joins.pop_back();
215  } while (!_joins.empty() && _joins.front()._target <= _finished);
// resolve outside the lock, presumably so continuations of join() futures
// can re-enter the pool without deadlocking
216  lk.unlock();
217  joined();
218  lk.lock();
219  }
220 
221  void notify(std::unique_lock<std::mutex> &mx) {
222  if (_to_start) {
223  _thr.emplace_back([](thread_pool_t *me){me->worker();}, this);
224  --_to_start;
225  mx.unlock();
226  } else {
227  mx.unlock();
228  _cond.notify_one();
229  }
230  }
231 
/// Worker thread body: repeatedly takes the oldest queued task and runs it
/// until stop() sets _stop. Tasks still queued at that point are not run here.
/// An exception thrown by a task is not caught, so it escapes the thread
/// function (which calls std::terminate).
232  void worker() {
// publish this pool for current() on this thread
233  _current = this;
234  std::unique_lock lk(_mx);
235  while (!_stop) {
236  if (_que.empty()) {
237  _cond.wait(lk);
238  } else {
// inner scope: the task object is destroyed before check_join() re-locks _mx
239  {
240  auto fn = std::move(_que.front());
241  _que.pop();
// run the task without holding the lock
242  lk.unlock();
243  fn();
// stop() was called from inside fn on this thread: it cleared _current and
// detached this thread, so leave without touching the pool again
244  if (_current != this) return;
245  }
// takes lk unlocked, returns it locked for the next loop iteration
246  check_join(lk);
247 
248  }
249  }
250  }
251 
252 };
253 
// Definition of the per-thread "current pool" pointer; it stays nullptr on
// threads that are not workers of a pool (worker() sets it, stop() clears it
// for a worker that stops its own pool).
254 template<typename CondVar>
255 inline thread_local thread_pool_t<CondVar> *thread_pool_t<CondVar>::_current = nullptr;
256 
257 
259 
267 
268 }
269 
270 
271 
272 
Exception thrown on an attempt to retrieve a value after the promise has been broken.
Definition: exceptions.h:10
Contains future value of T, can be co_awaited in coroutine.
Definition: future.h:417
std::coroutine_handle release()
Releases the handle.
Definition: prepared_coro.h:43
Contains a coroutine that is prepared to run.
Definition: prepared_coro.h:15
notify reject(std::exception_ptr e)
Rejects the future with an exception.
Definition: future.h:245
Carries a reference to a future<T>; a callable that sets the value of the associated future<T>.
Definition: future.h:73
bool enqueue(Fn &&fn)
Enqueues a function.
Definition: thread_pool.h:54
static constexpr bool await_ready() noexcept
co_await support (never ready)
Definition: thread_pool.h:134
future< void > join()
Returns a future that resolves once all tasks enqueued so far have been processed.
Definition: thread_pool.h:162
void await_suspend(std::coroutine_handle<> h)
co_await support: resumes the coroutine inside the thread pool.
Definition: thread_pool.h:148
static thread_pool_t * current()
Returns the current thread pool (when called from a thread managed by the pool).
Definition: thread_pool.h:180
bool operator>>(Fn &&fn)
Alias for enqueue.
Definition: thread_pool.h:66
thread_pool_t(unsigned int threads)
Constructs the thread pool.
Definition: thread_pool.h:37
void await_resume()
co_await support (nothing returned)
Definition: thread_pool.h:139
void stop()
Stops the thread pool.
Definition: thread_pool.h:105
requires std::invocable< Fn, Args... > auto run(Fn &&fn, Args &&... args) -> future< std::invoke_result_t< Fn, Args... > >
Run a function in the thread_pool.
Definition: thread_pool.h:81
bool is_stopped() const
Tests whether the pool has been stopped.
Definition: thread_pool.h:175
thread pool implementation
Definition: thread_pool.h:27
main namespace
Definition: aggregator.h:8