Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
EventCount.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2016 Dmitry Vyukov <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H
11#define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18// EventCount allows to wait for arbitrary predicates in non-blocking
19// algorithms. Think of condition variable, but wait predicate does not need to
20// be protected by a mutex. Usage:
21// Waiting thread does:
22//
23// if (predicate)
24// return act();
// EventCount::Waiter& w = waiters[my_index];
// ec.Prewait();
// if (predicate) {
//   ec.CancelWait();
//   return act();
// }
// ec.CommitWait(&w);
32//
33// Notifying thread does:
34//
35// predicate = true;
36// ec.Notify(true);
37//
38// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
39// cheap, but they are executed only if the preceding predicate check has
40// failed.
41//
42// Algorithm outline:
43// There are two main variables: predicate (managed by user) and state_.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// Waiting thread sets state_ then checks predicate, notifying thread sets
// predicate then checks state_. Due to seq_cst fences in between these
// operations it is guaranteed that either the waiter will see the predicate
// change and won't block, or the notifying thread will see the state_ change
// and will unblock the waiter, or both. But it can't happen that both threads
// don't see each other's changes, which would lead to deadlock.
class EventCount {
 public:
  class Waiter;

  // `waiters` must outlive this EventCount and provides one slot per
  // potentially-waiting thread; its size must fit in kWaiterBits.
  EventCount(MaxSizeVector<Waiter>& waiters) : state_(kStackMask), waiters_(waiters) {
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters: state_ must equal its initial value
    // (empty stack, zero pre-wait count, zero pending signals).
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate
  // and then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      // Register this thread in the pre-wait counter.
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      // seq_cst so the counter increment is ordered before the caller's
      // subsequent predicate re-check (pairs with the fence in Notify).
      if (state_.compare_exchange_weak(state, newstate, std::memory_order_seq_cst)) return;
    }
  }

  // CommitWait commits waiting after Prewait.
  // Either consumes an already-pending signal and returns immediately, or
  // pushes *w onto the waiter stack and blocks until Notify unparks it.
  void CommitWait(Waiter* w) {
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    // This waiter's stack-node value: its index in waiters_ combined with
    // its current ABA epoch in the high bits.
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately.
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from pre-wait counter and add to the waiter stack.
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        // Link to the previous stack top (its index and epoch bits).
        w->next.store(state & (kStackMask | kEpochMask), std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          // Advance the epoch for the next push (ABA protection) and block.
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know if the thread was also notified or not,
      // so we should not consume a signal unconditionally.
      // Only if number of waiters is equal to number of signals,
      // we know that the thread was notified and we must take away the signal.
      if (((state & kWaiterMask) >> kWaiterShift) == ((state & kSignalMask) >> kSignalShift)) newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    // Full fence: orders the caller's predicate write before the state_ read
    // below (the other half of the Dekker-style handshake with Prewait).
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty wait stack and set signal to number of pre-wait threads.
        newstate = (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kSignalInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) {
        if (!notifyAll && (signals < waiters)) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        // For notify-one, detach the popped waiter so Unpark wakes only it;
        // for notify-all, Unpark walks the whole popped chain via next links.
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    // next holds the stack link: index (and epoch bits) of the next waiter;
    // kStackMask in the index bits marks the end of the list.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    EIGEN_MUTEX mu;    // protects `state` for the Park/Unpark handshake
    EIGEN_CONDVAR cv;  // signaled by Unpark, waited on in Park
    uint64_t epoch = 0;  // ABA counter, advanced on each stack push
    unsigned state = kNotSignaled;
    enum {
      kNotSignaled,  // set by CommitWait before pushing onto the stack
      kWaiting,      // set by Park while blocked on cv
      kSignaled,     // set by Unpark; terminates Park's wait loop
    };
  };

 private:
  // State_ layout:
  // - low kWaiterBits is a stack of waiters committed wait
  //   (indexes in waiters_ array are used as stack elements,
  //   kStackMask means empty stack).
  // - next kWaiterBits is count of waiters in prewait state.
  // - next kWaiterBits is count of pending signals.
  // - remaining bits are ABA counter for the stack.
  //   (stored in Waiter node and incremented on push).
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1) << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  // Sanity-checks the invariants of a state_ value; with waiter=true,
  // additionally requires a non-zero pre-wait/waiter count.
  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    (void)waiters;
    (void)signals;
  }

  // Block the calling thread until w->state becomes kSignaled.
  // The while loop also absorbs spurious condvar wakeups.
  void Park(Waiter* w) {
    EIGEN_MUTEX_LOCK lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Wake w and every waiter reachable through its `next` links,
  // marking each kSignaled under its own mutex.
  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask ? nullptr : &waiters_[internal::convert_index<size_t>(wnext)];
      unsigned state;
      {
        EIGEN_MUTEX_LOCK lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  // Non-copyable: state_ encodes indices into waiters_ and cannot be shared.
  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
235
236} // namespace Eigen
237
238#endif // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H
Namespace containing all symbols from the Eigen library.
Definition Core:137