Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
Visitor.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_VISITOR_H
11#define EIGEN_VISITOR_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <typename Visitor, typename Derived, int UnrollCount,
21 bool Vectorize = (Derived::PacketAccess && functor_traits<Visitor>::PacketAccess), bool LinearAccess = false,
22 bool ShortCircuitEvaluation = false>
23struct visitor_impl;
24
25template <typename Visitor, bool ShortCircuitEvaluation = false>
26struct short_circuit_eval_impl {
27 // if short circuit evaluation is not used, do nothing
28 static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const Visitor&) { return false; }
29};
30template <typename Visitor>
31struct short_circuit_eval_impl<Visitor, true> {
32 // if short circuit evaluation is used, check the visitor
33 static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const Visitor& visitor) {
34 return visitor.done();
35 }
36};
37
38// unrolled inner-outer traversal
39template <typename Visitor, typename Derived, int UnrollCount, bool Vectorize, bool ShortCircuitEvaluation>
40struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, false, ShortCircuitEvaluation> {
41 // don't use short circuit evaulation for unrolled version
42 using Scalar = typename Derived::Scalar;
43 using Packet = typename packet_traits<Scalar>::type;
44 static constexpr bool RowMajor = Derived::IsRowMajor;
45 static constexpr int RowsAtCompileTime = Derived::RowsAtCompileTime;
46 static constexpr int ColsAtCompileTime = Derived::ColsAtCompileTime;
47 static constexpr int PacketSize = packet_traits<Scalar>::size;
48
49 static constexpr bool CanVectorize(int K) {
50 constexpr int InnerSizeAtCompileTime = RowMajor ? ColsAtCompileTime : RowsAtCompileTime;
51 if (InnerSizeAtCompileTime < PacketSize) return false;
52 return Vectorize && (InnerSizeAtCompileTime - (K % InnerSizeAtCompileTime) >= PacketSize);
53 }
54
55 template <int K = 0, bool Empty = (K == UnrollCount), std::enable_if_t<Empty, bool> = true>
56 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived&, Visitor&) {}
57
58 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
59 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> = true>
60 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
61 visitor.init(mat.coeff(0, 0), 0, 0);
62 run<1>(mat, visitor);
63 }
64
65 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
66 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> = true>
67 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
68 static constexpr int R = RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
69 static constexpr int C = RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
70 visitor(mat.coeff(R, C), R, C);
71 run<K + 1>(mat, visitor);
72 }
73
74 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
75 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> = true>
76 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
77 Packet P = mat.template packet<Packet>(0, 0);
78 visitor.initpacket(P, 0, 0);
79 run<PacketSize>(mat, visitor);
80 }
81
82 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
83 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> = true>
84 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
85 static constexpr int R = RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
86 static constexpr int C = RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
87 Packet P = mat.template packet<Packet>(R, C);
88 visitor.packet(P, R, C);
89 run<K + PacketSize>(mat, visitor);
90 }
91};
92
93// unrolled linear traversal
94template <typename Visitor, typename Derived, int UnrollCount, bool Vectorize, bool ShortCircuitEvaluation>
95struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, true, ShortCircuitEvaluation> {
96 // don't use short circuit evaulation for unrolled version
97 using Scalar = typename Derived::Scalar;
98 using Packet = typename packet_traits<Scalar>::type;
99 static constexpr int PacketSize = packet_traits<Scalar>::size;
100
101 static constexpr bool CanVectorize(int K) { return Vectorize && ((UnrollCount - K) >= PacketSize); }
102
103 // empty
104 template <int K = 0, bool Empty = (K == UnrollCount), std::enable_if_t<Empty, bool> = true>
105 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived&, Visitor&) {}
106
107 // scalar initialization
108 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
109 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> = true>
110 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
111 visitor.init(mat.coeff(0), 0);
112 run<1>(mat, visitor);
113 }
114
115 // scalar iteration
116 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
117 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> = true>
118 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
119 visitor(mat.coeff(K), K);
120 run<K + 1>(mat, visitor);
121 }
122
123 // vector initialization
124 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
125 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> = true>
126 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
127 Packet P = mat.template packet<Packet>(0);
128 visitor.initpacket(P, 0);
129 run<PacketSize>(mat, visitor);
130 }
131
132 // vector iteration
133 template <int K = 0, bool Empty = (K == UnrollCount), bool Initialize = (K == 0), bool DoVectorOp = CanVectorize(K),
134 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> = true>
135 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
136 Packet P = mat.template packet<Packet>(K);
137 visitor.packet(P, K);
138 run<K + PacketSize>(mat, visitor);
139 }
140};
141
142// dynamic scalar outer-inner traversal
143template <typename Visitor, typename Derived, bool ShortCircuitEvaluation>
144struct visitor_impl<Visitor, Derived, Dynamic, /*Vectorize=*/false, /*LinearAccess=*/false, ShortCircuitEvaluation> {
145 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
146 static constexpr bool RowMajor = Derived::IsRowMajor;
147
148 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
149 const Index innerSize = RowMajor ? mat.cols() : mat.rows();
150 const Index outerSize = RowMajor ? mat.rows() : mat.cols();
151 if (innerSize == 0 || outerSize == 0) return;
152 {
153 visitor.init(mat.coeff(0, 0), 0, 0);
154 if (short_circuit::run(visitor)) return;
155 for (Index i = 1; i < innerSize; ++i) {
156 Index r = RowMajor ? 0 : i;
157 Index c = RowMajor ? i : 0;
158 visitor(mat.coeff(r, c), r, c);
159 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
160 }
161 }
162 for (Index j = 1; j < outerSize; j++) {
163 for (Index i = 0; i < innerSize; ++i) {
164 Index r = RowMajor ? j : i;
165 Index c = RowMajor ? i : j;
166 visitor(mat.coeff(r, c), r, c);
167 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
168 }
169 }
170 }
171};
172
173// dynamic vectorized outer-inner traversal
174template <typename Visitor, typename Derived, bool ShortCircuitEvaluation>
175struct visitor_impl<Visitor, Derived, Dynamic, /*Vectorize=*/true, /*LinearAccess=*/false, ShortCircuitEvaluation> {
176 using Scalar = typename Derived::Scalar;
177 using Packet = typename packet_traits<Scalar>::type;
178 static constexpr int PacketSize = packet_traits<Scalar>::size;
179 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
180 static constexpr bool RowMajor = Derived::IsRowMajor;
181
182 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
183 const Index innerSize = RowMajor ? mat.cols() : mat.rows();
184 const Index outerSize = RowMajor ? mat.rows() : mat.cols();
185 if (innerSize == 0 || outerSize == 0) return;
186 {
187 Index i = 0;
188 if (innerSize < PacketSize) {
189 visitor.init(mat.coeff(0, 0), 0, 0);
190 i = 1;
191 } else {
192 Packet p = mat.template packet<Packet>(0, 0);
193 visitor.initpacket(p, 0, 0);
194 i = PacketSize;
195 }
196 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
197 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
198 Index r = RowMajor ? 0 : i;
199 Index c = RowMajor ? i : 0;
200 Packet p = mat.template packet<Packet>(r, c);
201 visitor.packet(p, r, c);
202 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
203 }
204 for (; i < innerSize; ++i) {
205 Index r = RowMajor ? 0 : i;
206 Index c = RowMajor ? i : 0;
207 visitor(mat.coeff(r, c), r, c);
208 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
209 }
210 }
211 for (Index j = 1; j < outerSize; j++) {
212 Index i = 0;
213 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
214 Index r = RowMajor ? j : i;
215 Index c = RowMajor ? i : j;
216 Packet p = mat.template packet<Packet>(r, c);
217 visitor.packet(p, r, c);
218 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
219 }
220 for (; i < innerSize; ++i) {
221 Index r = RowMajor ? j : i;
222 Index c = RowMajor ? i : j;
223 visitor(mat.coeff(r, c), r, c);
224 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
225 }
226 }
227 }
228};
229
230// dynamic scalar linear traversal
231template <typename Visitor, typename Derived, bool ShortCircuitEvaluation>
232struct visitor_impl<Visitor, Derived, Dynamic, /*Vectorize=*/false, /*LinearAccess=*/true, ShortCircuitEvaluation> {
233 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
234
235 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
236 const Index size = mat.size();
237 if (size == 0) return;
238 visitor.init(mat.coeff(0), 0);
239 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
240 for (Index k = 1; k < size; k++) {
241 visitor(mat.coeff(k), k);
242 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
243 }
244 }
245};
246
247// dynamic vectorized linear traversal
248template <typename Visitor, typename Derived, bool ShortCircuitEvaluation>
249struct visitor_impl<Visitor, Derived, Dynamic, /*Vectorize=*/true, /*LinearAccess=*/true, ShortCircuitEvaluation> {
250 using Scalar = typename Derived::Scalar;
251 using Packet = typename packet_traits<Scalar>::type;
252 static constexpr int PacketSize = packet_traits<Scalar>::size;
253 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
254
255 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& mat, Visitor& visitor) {
256 const Index size = mat.size();
257 if (size == 0) return;
258 Index k = 0;
259 if (size < PacketSize) {
260 visitor.init(mat.coeff(0), 0);
261 k = 1;
262 } else {
263 Packet p = mat.template packet<Packet>(k);
264 visitor.initpacket(p, k);
265 k = PacketSize;
266 }
267 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
268 for (; k + PacketSize - 1 < size; k += PacketSize) {
269 Packet p = mat.template packet<Packet>(k);
270 visitor.packet(p, k);
271 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
272 }
273 for (; k < size; k++) {
274 visitor(mat.coeff(k), k);
275 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor)) return;
276 }
277 }
278};
279
280// evaluator adaptor
281template <typename XprType>
282class visitor_evaluator {
283 public:
284 typedef evaluator<XprType> Evaluator;
285 typedef typename XprType::Scalar Scalar;
286 using Packet = typename packet_traits<Scalar>::type;
287 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
288
289 static constexpr bool PacketAccess = static_cast<bool>(Evaluator::Flags & PacketAccessBit);
290 static constexpr bool LinearAccess = static_cast<bool>(Evaluator::Flags & LinearAccessBit);
291 static constexpr bool IsRowMajor = static_cast<bool>(XprType::IsRowMajor);
292 static constexpr int RowsAtCompileTime = XprType::RowsAtCompileTime;
293 static constexpr int ColsAtCompileTime = XprType::ColsAtCompileTime;
294 static constexpr int XprAlignment = Evaluator::Alignment;
295 static constexpr int CoeffReadCost = Evaluator::CoeffReadCost;
296
297 EIGEN_DEVICE_FUNC explicit visitor_evaluator(const XprType& xpr) : m_evaluator(xpr), m_xpr(xpr) {}
298
299 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
300 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
301 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_xpr.size(); }
302 // outer-inner access
303 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
304 return m_evaluator.coeff(row, col);
305 }
306 template <typename Packet, int Alignment = Unaligned>
307 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(Index row, Index col) const {
308 return m_evaluator.template packet<Alignment, Packet>(row, col);
309 }
310 // linear access
311 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_evaluator.coeff(index); }
312 template <typename Packet, int Alignment = XprAlignment>
313 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(Index index) const {
314 return m_evaluator.template packet<Alignment, Packet>(index);
315 }
316
317 protected:
318 Evaluator m_evaluator;
319 const XprType& m_xpr;
320};
321
322template <typename Derived, typename Visitor, bool ShortCircuitEvaulation>
323struct visit_impl {
324 using Evaluator = visitor_evaluator<Derived>;
325 using Scalar = typename DenseBase<Derived>::Scalar;
326
327 static constexpr bool IsRowMajor = DenseBase<Derived>::IsRowMajor;
328 static constexpr int SizeAtCompileTime = DenseBase<Derived>::SizeAtCompileTime;
329 static constexpr int RowsAtCompileTime = DenseBase<Derived>::RowsAtCompileTime;
330 static constexpr int ColsAtCompileTime = DenseBase<Derived>::ColsAtCompileTime;
331 static constexpr int InnerSizeAtCompileTime = IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime;
332 static constexpr int OuterSizeAtCompileTime = IsRowMajor ? RowsAtCompileTime : ColsAtCompileTime;
333
334 static constexpr bool LinearAccess =
335 Evaluator::LinearAccess && static_cast<bool>(functor_traits<Visitor>::LinearAccess);
336 static constexpr bool Vectorize = Evaluator::PacketAccess && static_cast<bool>(functor_traits<Visitor>::PacketAccess);
337
338 static constexpr int PacketSize = packet_traits<Scalar>::size;
339 static constexpr int VectorOps =
340 Vectorize ? (LinearAccess ? (SizeAtCompileTime / PacketSize)
341 : (OuterSizeAtCompileTime * (InnerSizeAtCompileTime / PacketSize)))
342 : 0;
343 static constexpr int ScalarOps = SizeAtCompileTime - (VectorOps * PacketSize);
344 // treat vector op and scalar op as same cost for unroll logic
345 static constexpr int TotalOps = VectorOps + ScalarOps;
346
347 static constexpr int UnrollCost = int(Evaluator::CoeffReadCost) + int(functor_traits<Visitor>::Cost);
348 static constexpr bool Unroll = (SizeAtCompileTime != Dynamic) && ((TotalOps * UnrollCost) <= EIGEN_UNROLLING_LIMIT);
349 static constexpr int UnrollCount = Unroll ? int(SizeAtCompileTime) : Dynamic;
350
351 using impl = visitor_impl<Visitor, Evaluator, UnrollCount, Vectorize, LinearAccess, ShortCircuitEvaulation>;
352
353 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const DenseBase<Derived>& mat, Visitor& visitor) {
354 Evaluator evaluator(mat.derived());
355 impl::run(evaluator, visitor);
356 }
357};
358
359} // end namespace internal
360
380template <typename Derived>
381template <typename Visitor>
382EIGEN_DEVICE_FUNC void DenseBase<Derived>::visit(Visitor& visitor) const {
383 using impl = internal::visit_impl<Derived, Visitor, /*ShortCircuitEvaulation*/ false>;
384 impl::run(derived(), visitor);
385}
386
387namespace internal {
388
392template <typename Derived>
393struct coeff_visitor {
394 // default initialization to avoid countless invalid maybe-uninitialized warnings by gcc
395 EIGEN_DEVICE_FUNC coeff_visitor() : row(-1), col(-1), res(0) {}
396 typedef typename Derived::Scalar Scalar;
397 Index row, col;
398 Scalar res;
399 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index i, Index j) {
400 res = value;
401 row = i;
402 col = j;
403 }
404};
405
406template <typename Scalar, int NaNPropagation, bool is_min = true>
407struct minmax_compare {
408 typedef typename packet_traits<Scalar>::type Packet;
409 static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a < b; }
410 static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_min<NaNPropagation>(p); }
411};
412
413template <typename Scalar, int NaNPropagation>
414struct minmax_compare<Scalar, NaNPropagation, false> {
415 typedef typename packet_traits<Scalar>::type Packet;
416 static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a > b; }
417 static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p); }
418};
419
420// Default implementation used by non-floating types, where we do not
421// need special logic for NaN handling.
422template <typename Derived, bool is_min, int NaNPropagation,
423 bool isInt = NumTraits<typename Derived::Scalar>::IsInteger>
424struct minmax_coeff_visitor : coeff_visitor<Derived> {
425 using Scalar = typename Derived::Scalar;
426 using Packet = typename packet_traits<Scalar>::type;
427 using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
428 static constexpr Index PacketSize = packet_traits<Scalar>::size;
429
430 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
431 if (Comparator::compare(value, this->res)) {
432 this->res = value;
433 this->row = i;
434 this->col = j;
435 }
436 }
437 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
438 Scalar value = Comparator::predux(p);
439 if (Comparator::compare(value, this->res)) {
440 const Packet range = preverse(plset<Packet>(Scalar(1)));
441 Packet mask = pcmp_eq(pset1<Packet>(value), p);
442 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
443 this->res = value;
444 this->row = Derived::IsRowMajor ? i : i + max_idx;
445 this->col = Derived::IsRowMajor ? j + max_idx : j;
446 }
447 }
448 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
449 Scalar value = Comparator::predux(p);
450 const Packet range = preverse(plset<Packet>(Scalar(1)));
451 Packet mask = pcmp_eq(pset1<Packet>(value), p);
452 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
453 this->res = value;
454 this->row = Derived::IsRowMajor ? i : i + max_idx;
455 this->col = Derived::IsRowMajor ? j + max_idx : j;
456 }
457};
458
459// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN,
460// in which case, row=0, col=0 is returned for the location.
461template <typename Derived, bool is_min>
462struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers, false> : coeff_visitor<Derived> {
463 typedef typename Derived::Scalar Scalar;
464 using Packet = typename packet_traits<Scalar>::type;
465 using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
466
467 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
468 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
469 this->res = value;
470 this->row = i;
471 this->col = j;
472 }
473 }
474 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
475 const Index PacketSize = packet_traits<Scalar>::size;
476 Scalar value = Comparator::predux(p);
477 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
478 const Packet range = preverse(plset<Packet>(Scalar(1)));
479 /* mask will be zero for NaNs, so they will be ignored. */
480 Packet mask = pcmp_eq(pset1<Packet>(value), p);
481 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
482 this->res = value;
483 this->row = Derived::IsRowMajor ? i : i + max_idx;
484 this->col = Derived::IsRowMajor ? j + max_idx : j;
487 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
488 const Index PacketSize = packet_traits<Scalar>::size;
489 Scalar value = Comparator::predux(p);
490 if ((numext::isnan)(value)) {
491 this->res = value;
492 this->row = 0;
493 this->col = 0;
494 return;
495 }
496 const Packet range = preverse(plset<Packet>(Scalar(1)));
497 /* mask will be zero for NaNs, so they will be ignored. */
498 Packet mask = pcmp_eq(pset1<Packet>(value), p);
499 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
500 this->res = value;
501 this->row = Derived::IsRowMajor ? i : i + max_idx;
502 this->col = Derived::IsRowMajor ? j + max_idx : j;
503 }
504};
505
506// Propagate NaNs. If the matrix contains NaN, the location of the first NaN
507// will be returned in row and col.
508template <typename Derived, bool is_min, int NaNPropagation>
509struct minmax_coeff_visitor<Derived, is_min, NaNPropagation, false> : coeff_visitor<Derived> {
510 typedef typename Derived::Scalar Scalar;
511 using Packet = typename packet_traits<Scalar>::type;
512 using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
513
514 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
515 const bool value_is_nan = (numext::isnan)(value);
516 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
517 this->res = value;
518 this->row = i;
519 this->col = j;
520 }
521 }
522 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
523 const Index PacketSize = packet_traits<Scalar>::size;
524 Scalar value = Comparator::predux(p);
525 const bool value_is_nan = (numext::isnan)(value);
526 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
527 const Packet range = preverse(plset<Packet>(Scalar(1)));
528 // If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value.
529 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
530 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
531 this->res = value;
532 this->row = Derived::IsRowMajor ? i : i + max_idx;
533 this->col = Derived::IsRowMajor ? j + max_idx : j;
534 }
535 }
536 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
537 const Index PacketSize = packet_traits<Scalar>::size;
538 Scalar value = Comparator::predux(p);
539 const bool value_is_nan = (numext::isnan)(value);
540 const Packet range = preverse(plset<Packet>(Scalar(1)));
541 // If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value.
542 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
543 Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
544 this->res = value;
545 this->row = Derived::IsRowMajor ? i : i + max_idx;
546 this->col = Derived::IsRowMajor ? j + max_idx : j;
547 }
548};
549
550template <typename Derived, bool is_min, int NaNPropagation>
551struct functor_traits<minmax_coeff_visitor<Derived, is_min, NaNPropagation>> {
552 using Scalar = typename Derived::Scalar;
553 enum { Cost = NumTraits<Scalar>::AddCost, LinearAccess = false, PacketAccess = packet_traits<Scalar>::HasCmp };
554};
555
556template <typename Scalar>
557struct all_visitor {
558 using result_type = bool;
559 using Packet = typename packet_traits<Scalar>::type;
560 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index, Index) { res = (value != Scalar(0)); }
561 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index) { res = (value != Scalar(0)); }
562 EIGEN_DEVICE_FUNC inline bool all_predux(const Packet& p) const { return !predux_any(pcmp_eq(p, pzero(p))); }
563 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index, Index) { res = all_predux(p); }
564 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index) { res = all_predux(p); }
565 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index, Index) { res = res && (value != Scalar(0)); }
566 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index) { res = res && (value != Scalar(0)); }
567 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index, Index) { res = res && all_predux(p); }
568 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index) { res = res && all_predux(p); }
569 EIGEN_DEVICE_FUNC inline bool done() const { return !res; }
570 bool res = true;
571};
572template <typename Scalar>
573struct functor_traits<all_visitor<Scalar>> {
574 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess = true, PacketAccess = packet_traits<Scalar>::HasCmp };
575};
576
577template <typename Scalar>
578struct any_visitor {
579 using result_type = bool;
580 using Packet = typename packet_traits<Scalar>::type;
581 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index, Index) { res = (value != Scalar(0)); }
582 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index) { res = (value != Scalar(0)); }
583 EIGEN_DEVICE_FUNC inline bool any_predux(const Packet& p) const {
584 return predux_any(pandnot(ptrue(p), pcmp_eq(p, pzero(p))));
585 }
586 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index, Index) { res = any_predux(p); }
587 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index) { res = any_predux(p); }
588 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index, Index) { res = res || (value != Scalar(0)); }
589 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index) { res = res || (value != Scalar(0)); }
590 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index, Index) { res = res || any_predux(p); }
591 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index) { res = res || any_predux(p); }
592 EIGEN_DEVICE_FUNC inline bool done() const { return res; }
593 bool res = false;
594};
595template <typename Scalar>
596struct functor_traits<any_visitor<Scalar>> {
597 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess = true, PacketAccess = packet_traits<Scalar>::HasCmp };
598};
599
600template <typename Scalar>
601struct count_visitor {
602 using result_type = Index;
603 using Packet = typename packet_traits<Scalar>::type;
604 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index, Index) { res = value != Scalar(0) ? 1 : 0; }
605 EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index) { res = value != Scalar(0) ? 1 : 0; }
606 EIGEN_DEVICE_FUNC inline Index count_redux(const Packet& p) const {
607 const Packet cst_one = pset1<Packet>(Scalar(1));
608 Packet true_vals = pandnot(cst_one, pcmp_eq(p, pzero(p)));
609 Scalar num_true = predux(true_vals);
610 return static_cast<Index>(num_true);
611 }
612 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index, Index) { res = count_redux(p); }
613 EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index) { res = count_redux(p); }
614 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index, Index) {
615 if (value != Scalar(0)) res++;
616 }
617 EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index) {
618 if (value != Scalar(0)) res++;
619 }
620 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index, Index) { res += count_redux(p); }
621 EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index) { res += count_redux(p); }
622 Index res = 0;
623};
624
625template <typename Scalar>
626struct functor_traits<count_visitor<Scalar>> {
627 enum {
628 Cost = NumTraits<Scalar>::AddCost,
629 LinearAccess = true,
630 // predux is problematic for bool
631 PacketAccess = packet_traits<Scalar>::HasCmp && packet_traits<Scalar>::HasAdd && !is_same<Scalar, bool>::value
632 };
633};
634
635} // end namespace internal
636
648template <typename Derived>
649template <int NaNPropagation, typename IndexType>
650EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* rowId,
651 IndexType* colId) const {
652 eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
653
654 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
655 this->visit(minVisitor);
656 *rowId = minVisitor.row;
657 if (colId) *colId = minVisitor.col;
658 return minVisitor.res;
659}
660
672template <typename Derived>
673template <int NaNPropagation, typename IndexType>
674EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* index) const {
675 eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
676 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
677
678 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
679 this->visit(minVisitor);
680 *index = IndexType((RowsAtCompileTime == 1) ? minVisitor.col : minVisitor.row);
681 return minVisitor.res;
682}
683
695template <typename Derived>
696template <int NaNPropagation, typename IndexType>
697EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* rowPtr,
698 IndexType* colPtr) const {
699 eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
700
701 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
702 this->visit(maxVisitor);
703 *rowPtr = maxVisitor.row;
704 if (colPtr) *colPtr = maxVisitor.col;
705 return maxVisitor.res;
706}
707
719template <typename Derived>
720template <int NaNPropagation, typename IndexType>
721EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* index) const {
722 eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
723
724 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
725 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
726 this->visit(maxVisitor);
727 *index = (RowsAtCompileTime == 1) ? maxVisitor.col : maxVisitor.row;
728 return maxVisitor.res;
729}
730
738template <typename Derived>
739EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::all() const {
740 using Visitor = internal::all_visitor<Scalar>;
741 using impl = internal::visit_impl<Derived, Visitor, /*ShortCircuitEvaulation*/ true>;
742 Visitor visitor;
743 impl::run(derived(), visitor);
744 return visitor.res;
745}
746
751template <typename Derived>
752EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::any() const {
753 using Visitor = internal::any_visitor<Scalar>;
754 using impl = internal::visit_impl<Derived, Visitor, /*ShortCircuitEvaulation*/ true>;
755 Visitor visitor;
756 impl::run(derived(), visitor);
757 return visitor.res;
758}
759
764template <typename Derived>
765EIGEN_DEVICE_FUNC Index DenseBase<Derived>::count() const {
766 using Visitor = internal::count_visitor<Scalar>;
767 using impl = internal::visit_impl<Derived, Visitor, /*ShortCircuitEvaulation*/ false>;
768 Visitor visitor;
769 impl::run(derived(), visitor);
770 return visitor.res;
771}
772
773template <typename Derived>
774EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::hasNaN() const {
775 return derived().cwiseTypedNotEqual(derived()).any();
776}
777
782template <typename Derived>
783EIGEN_DEVICE_FUNC inline bool DenseBase<Derived>::allFinite() const {
784 return derived().array().isFinite().all();
785}
786
787} // end namespace Eigen
788
789#endif // EIGEN_VISITOR_H
Base class for all dense matrices, vectors, and arrays.
Definition DenseBase.h:44
internal::traits< Derived >::Scalar Scalar
Definition DenseBase.h:62
bool any() const
Definition Visitor.h:752
@ SizeAtCompileTime
Definition DenseBase.h:108
@ IsRowMajor
Definition DenseBase.h:166
@ ColsAtCompileTime
Definition DenseBase.h:102
@ RowsAtCompileTime
Definition DenseBase.h:96
bool all() const
Definition Visitor.h:739
@ PropagateNumbers
Definition Constants.h:344
@ RowMajor
Definition Constants.h:320
const unsigned int PacketAccessBit
Definition Constants.h:97
const unsigned int LinearAccessBit
Definition Constants.h:133
Namespace containing all symbols from the Eigen library.
Definition Core:137
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:83
const int Dynamic
Definition Constants.h:25