#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H

#include "./InternalHeaderCheck.h"

namespace Eigen {

template <typename Environment>
class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;

  ThreadPoolTempl(int num_threads, Environment env = Environment())
      : ThreadPoolTempl(num_threads, true, env) {}

  ThreadPoolTempl(int num_threads, bool allow_spinning, Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        global_steal_partition_(EncodePartition(0, num_threads_)),
        blocked_(0),
        spinning_(false),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads_);
    // Calculate coprimes of all numbers [1, num_threads]; they are later used
    // as strides when iterating over victim queues during work stealing.
    eigen_plain_assert(num_threads_ < kMaxThreads);
    for (int i = 1; i <= num_threads_; ++i) {
      all_coprimes_.emplace_back(i);
      ComputeCoprimes(i, &all_coprimes_.back());
    }
#ifndef EIGEN_THREAD_LOCAL
    init_barrier_.reset(new Barrier(num_threads_));
#endif
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
#ifndef EIGEN_THREAD_LOCAL
    // Wait for the workers to populate per_thread_map_; otherwise we might race
    // with them in Schedule or CurrentThreadId.
    init_barrier_->Wait();
#endif
  }

  ~ThreadPoolTempl() {
    done_ = true;
    if (!cancelled_) {
      ec_.Notify(true);  // wake blocked workers so they observe done_ and exit
    } else {
      // We were cancelled: flush leftover tasks so the queue destructors do not assert.
      for (size_t i = 0; i < thread_data_.size(); i++) thread_data_[i].queue.Flush();
    }
    // Join threads explicitly (by destroying them) to control destruction order.
    for (size_t i = 0; i < thread_data_.size(); ++i) thread_data_[i].thread.reset();
  }

  void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
    eigen_plain_assert(partitions.size() == static_cast<std::size_t>(num_threads_));

    for (int i = 0; i < num_threads_; i++) {
      const auto& pair = partitions[i];
      unsigned start = pair.first, end = pair.second;
      AssertBounds(start, end);
      unsigned val = EncodePartition(start, end);
      SetStealPartition(i, val);
    }
  }

  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE { ScheduleWithHint(std::move(fn), 0, num_threads_); }

  void ScheduleWithHint(std::function<void()> fn, int start, int limit) override {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's own queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue in [start, limit).
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    if (!t.f) {
      ec_.Notify(false);  // the task was enqueued, wake one waiting worker
    } else {
      env_.ExecuteTask(t);  // push failed (queue full), execute the task directly
    }
  }
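  // Illustration (not part of the original header): when called from a thread
  // that is not a worker of this pool, ScheduleWithHint enqueues the task on a
  // randomly chosen queue in [start, limit), e.g.
  //   pool.ScheduleWithHint([] { /* work */ }, 0, 2);  // target workers 0 and 1
  // Other workers may still pick the task up later via GlobalSteal().
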
  void Cancel() EIGEN_OVERRIDE {
    cancelled_ = true;
    done_ = true;
#ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
    // Let each thread know it has been cancelled.
    for (size_t i = 0; i < thread_data_.size(); i++) thread_data_[i].thread->OnCancel();
#endif
    ec_.Notify(true);  // wake idle threads so they can exit on their own
  }

  int NumThreads() const EIGEN_FINAL { return num_threads_; }

  int CurrentThreadId() const EIGEN_FINAL {
    const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }
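  // Illustration (not part of the original header):
  //   Eigen::ThreadPool pool(4);
  //   int outside = pool.CurrentThreadId();              // -1: not a worker of this pool
  //   pool.Schedule([&pool] { int inside = pool.CurrentThreadId(); });  // 0..3 inside a task
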
 private:
  // Encode both the start and the limit of a thread's steal partition in a
  // single atomic<unsigned>. This assumes num_threads_ < 65536.
  static const int kMaxPartitionBits = 16;
  static const int kMaxThreads = 1 << kMaxPartitionBits;

  inline unsigned EncodePartition(unsigned start, unsigned limit) { return (start << kMaxPartitionBits) | limit; }

  inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }
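  // Illustration (not part of the original header): with kMaxPartitionBits == 16,
  //   unsigned p = EncodePartition(2, 5);  // p == (2 << 16) | 5 == 0x20005
  //   unsigned s, l;
  //   DecodePartition(p, &s, &l);          // s == 2, l == 5
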
  void AssertBounds(int start, int end) {
    eigen_plain_assert(start >= 0);
    eigen_plain_assert(start < end);
    eigen_plain_assert(end <= num_threads_);
  }

  inline void SetStealPartition(size_t i, unsigned val) {
    thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
  }

  inline unsigned GetStealPartition(int i) { return thread_data_[i].steal_partition.load(std::memory_order_relaxed); }

  void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
    for (int i = 1; i <= N; i++) {
      unsigned a = i, b = N;
      while (b != 0) { unsigned tmp = a; a = b; b = tmp % b; }  // gcd(i, N) via Euclid
      if (a == 1) coprimes->push_back(i);  // keep i only if it is coprime with N
    }
  }
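  // Illustration (not part of the original header): ComputeCoprimes(10, &v)
  // yields v == {1, 3, 7, 9}. Stepping through 10 queues with any of these
  // strides (modulo 10) visits every queue exactly once, which is what Steal
  // and NonEmptyQueueIndex rely on to scan all victims in pseudo-random order.
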
  typedef typename Environment::EnvThread Thread;

  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
    ThreadPoolTempl* pool;  // parent pool, or null for normal threads
    uint64_t rand;          // random generator state
    int thread_id;          // worker thread index in the pool
#ifndef EIGEN_THREAD_LOCAL
    char pad_[128];  // prevent false sharing
#endif
  };

  struct ThreadData {
    constexpr ThreadData() : thread(), steal_partition(0), queue() {}
    std::unique_ptr<Thread> thread;
    std::atomic<unsigned> steal_partition;
    Queue queue;
  };

  Environment env_;
  const int num_threads_;
  const bool allow_spinning_;
  MaxSizeVector<ThreadData> thread_data_;
  MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  unsigned global_steal_partition_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> spinning_;
  std::atomic<bool> done_;
  std::atomic<bool> cancelled_;
  EventCount ec_;
#ifndef EIGEN_THREAD_LOCAL
  std::unique_ptr<Barrier> init_barrier_;
  EIGEN_MUTEX per_thread_map_mutex_;  // protects per_thread_map_
  std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
#endif

  void WorkerLoop(int thread_id) {
#ifndef EIGEN_THREAD_LOCAL
    std::unique_ptr<PerThread> new_pt(new PerThread());
    per_thread_map_mutex_.lock();
    bool insertOK = per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
    eigen_plain_assert(insertOK);
    EIGEN_UNUSED_VARIABLE(insertOK);
    per_thread_map_mutex_.unlock();
    init_barrier_->Notify();
    init_barrier_->Wait();
#endif
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = GlobalThreadIdHash();
    pt->thread_id = thread_id;
    Queue& q = thread_data_[thread_id].queue;
    EventCount::Waiter* waiter = &waiters_[thread_id];
    // Spin for a bounded number of iterations before blocking in WaitForWork.
    const int spin_count = allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
    if (num_threads_ == 1) {
      // Single-threaded pool: no point in stealing, just poll the own queue.
      while (!cancelled_) {
        Task t = q.PopFront();
        for (int i = 0; i < spin_count && !t.f; i++) {
          if (!cancelled_.load(std::memory_order_relaxed)) {
            t = q.PopFront();
          }
        }
        if (!t.f) {
          if (!WaitForWork(waiter, &t)) return;  // the pool is shutting down
        }
        if (t.f) env_.ExecuteTask(t);
      }
    } else {
      while (!cancelled_) {
        Task t = q.PopFront();
        if (!t.f) {
          t = LocalSteal();
          if (!t.f) {
            t = GlobalSteal();
            if (!t.f) {
              // Leave at most one thread spinning; this reduces latency.
              if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
                for (int i = 0; i < spin_count && !t.f; i++) {
                  if (!cancelled_.load(std::memory_order_relaxed)) {
                    t = GlobalSteal();
                  } else {
                    return;
                  }
                }
                spinning_ = false;
              }
              if (!t.f) {
                if (!WaitForWork(waiter, &t)) return;  // the pool is shutting down
              }
            }
          }
        }
        if (t.f) env_.ExecuteTask(t);
      }
    }
  }

  // Steal tries to steal work from other worker threads in the range
  // [start, limit) in a best-effort manner.
  Task Steal(unsigned start, unsigned limit) {
    PerThread* pt = GetPerThread();
    const size_t size = limit - start;
    unsigned r = Rand(&pt->rand);
    // Reduce r into the [0, size) range without a modulo (multiplicative reduction).
    eigen_plain_assert(all_coprimes_[size - 1].size() < (1 << 30));
    unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
    unsigned index = ((uint64_t)all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
    unsigned inc = all_coprimes_[size - 1][index];

    for (unsigned i = 0; i < size; i++) {
      eigen_plain_assert(start + victim < limit);
      Task t = thread_data_[start + victim].queue.PopBack();
      if (t.f) return t;
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return Task();
  }
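  // Illustration (not part of the original header): ((uint64_t)r * (uint64_t)size) >> 32
  // maps a 32-bit random value r onto [0, size) without a modulo operation,
  // e.g. r == 0x80000000 and size == 8 gives victim == 4.
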
  // Steals work from worker threads within the calling thread's steal partition.
  Task LocalSteal() {
    PerThread* pt = GetPerThread();
    unsigned partition = GetStealPartition(pt->thread_id);
    // If the thread's steal partition is the same as the global partition,
    // there is no need to go through the steal loop twice.
    if (global_steal_partition_ == partition) return Task();
    unsigned start, limit;
    DecodePartition(partition, &start, &limit);
    AssertBounds(start, limit);

    return Steal(start, limit);
  }

  // Steals work from any other thread in the pool.
  Task GlobalSteal() { return Steal(0, num_threads_); }

  // Blocks until new work is available (returns true) or it is time to exit
  // (returns false); may also hand back a task to execute in *t.
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_plain_assert(!t->f);
    ec_.Prewait();
    int victim = NonEmptyQueueIndex();  // reliable emptiness check before blocking
    if (victim != -1) {
      ec_.CancelWait();
      if (cancelled_) return false;
      *t = thread_data_[victim].queue.PopBack();
      return true;
    }
    blocked_++;  // the number of blocked threads is the termination condition
    if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
      ec_.CancelWait();
      // Re-check the queues: work may have been submitted just before done_ was set.
      if (NonEmptyQueueIndex() != -1) { blocked_--; return true; }
      blocked_--;
      ec_.Notify(true);  // stable termination state, wake everyone up to exit
      return false;
    }
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }

  // Returns the index of a non-empty worker queue, or -1 if all queues are empty.
  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    const size_t size = thread_data_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!thread_data_[victim].queue.Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= static_cast<unsigned int>(size);
      }
    }
    return -1;
  }

  static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
    return std::hash<std::thread::id>()(std::this_thread::get_id());
  }

  EIGEN_STRONG_INLINE PerThread* GetPerThread() {
#ifndef EIGEN_THREAD_LOCAL
    static PerThread dummy;
    auto it = per_thread_map_.find(GlobalThreadIdHash());
    if (it == per_thread_map_.end()) {
      return &dummy;
    } else {
      return it->second.get();
    }
#else
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
#endif
  }

  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state.
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output (PCG-XSH-RS scheme).
    return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
  }
};

typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
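// A minimal usage sketch (not part of this header). It assumes the ThreadPool
// module header is included (Eigen/ThreadPool in recent versions,
// unsupported/Eigen/CXX11/ThreadPool in Eigen 3.4) and uses the Eigen::Barrier
// helper from the same module:
//
//   Eigen::ThreadPool pool(4);       // 4 worker threads
//   Eigen::Barrier barrier(10);      // wait for 10 tasks
//   for (int i = 0; i < 10; ++i) {
//     pool.Schedule([&barrier] { /* do some work */ barrier.Notify(); });
//   }
//   barrier.Wait();                  // block until all scheduled tasks have run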