10#ifndef EIGEN_VISITOR_H
11#define EIGEN_VISITOR_H
14#include "./InternalHeaderCheck.h"
// Primary template of the visitor driver. Specializations below select the
// traversal by UnrollCount (a compile-time count, or Dynamic for a runtime loop),
// Vectorize (packet access), LinearAccess (linear vs. (row, col) indexing) and
// ShortCircuitEvaluation (whether the visitor may stop the traversal early).
// NOTE(review): the declarator that follows this parameter list is not present in this listing.
20template <
typename Visitor,
typename Derived,
int UnrollCount,
21 bool Vectorize = (Derived::PacketAccess && functor_traits<Visitor>::PacketAccess),
bool LinearAccess =
false,
22 bool ShortCircuitEvaluation =
false>
// Early-exit hook queried by the runtime visitor loops after each step.
// Default (short-circuit disabled): always reports "not done", so the loops
// visit every coefficient.
25template <
typename Visitor,
bool ShortCircuitEvaluation = false>
26struct short_circuit_eval_impl {
28 static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool run(
const Visitor&) {
return false; }
// Short-circuit enabled: defers to visitor.done(), so a visitor whose result is
// already determined (e.g. all_visitor / any_visitor) can end the traversal.
30template <
typename Visitor>
31struct short_circuit_eval_impl<Visitor, true> {
33 static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool run(
const Visitor& visitor) {
34 return visitor.done();
// Compile-time unrolled traversal with (row, col) indexing (LinearAccess == false).
// Each run<K> overload performs step K of UnrollCount. The overload is picked via
// enable_if on: Empty (K == UnrollCount), Initialize (K == 0) and DoVectorOp
// (a full packet fits in what remains of the current inner vector).
// Step K is mapped to (R, C) following the storage order.
39template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
40struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, false, ShortCircuitEvaluation> {
42 using Scalar =
typename Derived::Scalar;
43 using Packet =
typename packet_traits<Scalar>::type;
44 static constexpr bool RowMajor = Derived::IsRowMajor;
45 static constexpr int RowsAtCompileTime = Derived::RowsAtCompileTime;
46 static constexpr int ColsAtCompileTime = Derived::ColsAtCompileTime;
47 static constexpr int PacketSize = packet_traits<Scalar>::size;
// True when vectorization is enabled and a packet starting at step K does not
// cross the end of the current inner vector.
49 static constexpr bool CanVectorize(
int K) {
50 constexpr int InnerSizeAtCompileTime =
RowMajor ? ColsAtCompileTime : RowsAtCompileTime;
51 if (InnerSizeAtCompileTime < PacketSize)
return false;
52 return Vectorize && (InnerSizeAtCompileTime - (K % InnerSizeAtCompileTime) >= PacketSize);
// End of the unrolled sequence: nothing left to visit.
55 template <
int K = 0,
bool Empty = (K == UnrollCount), std::enable_if_t<Empty,
bool> = true>
56 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
// First step, scalar path: seed the visitor with coefficient (0, 0).
// NOTE(review): the statement continuing the recursion is not visible in this listing.
58 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
59 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
60 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
61 visitor.init(mat.coeff(0, 0), 0, 0);
// Later scalar step: visit the coefficient for step K, then recurse to K + 1.
65 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
66 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
67 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
68 static constexpr int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
69 static constexpr int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
70 visitor(mat.coeff(R, C), R, C);
71 run<K + 1>(mat, visitor);
// First step, packet path: seed with the packet at (0, 0) and resume at step PacketSize.
74 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
75 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
76 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
77 Packet P = mat.template packet<Packet>(0, 0);
78 visitor.initpacket(P, 0, 0);
79 run<PacketSize>(mat, visitor);
// Later packet step: visit the packet starting at step K and advance by PacketSize.
82 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
83 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
84 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
85 static constexpr int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
86 static constexpr int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
87 Packet P = mat.template packet<Packet>(R, C);
88 visitor.packet(P, R, C);
89 run<K + PacketSize>(mat, visitor);
// Compile-time unrolled traversal with linear indexing (LinearAccess == true).
// The run<K> overloads mirror the (row, col) version, but step K is used
// directly as the linear coefficient index.
94template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
95struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, true, ShortCircuitEvaluation> {
97 using Scalar =
typename Derived::Scalar;
98 using Packet =
typename packet_traits<Scalar>::type;
99 static constexpr int PacketSize = packet_traits<Scalar>::size;
// True when vectorization is enabled and at least one full packet remains after step K.
101 static constexpr bool CanVectorize(
int K) {
return Vectorize && ((UnrollCount - K) >= PacketSize); }
// End of the unrolled sequence: nothing left to visit.
104 template <
int K = 0,
bool Empty = (K == UnrollCount), std::enable_if_t<Empty,
bool> = true>
105 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
// First step, scalar path: seed with coefficient 0, then continue at index 1.
108 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
109 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
110 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
111 visitor.init(mat.coeff(0), 0);
112 run<1>(mat, visitor);
// Later scalar step: visit coefficient K, then recurse to K + 1.
116 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
117 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
118 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
119 visitor(mat.coeff(K), K);
120 run<K + 1>(mat, visitor);
// First step, packet path: seed with the packet at index 0 and resume at PacketSize.
124 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
125 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
126 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
127 Packet P = mat.template packet<Packet>(0);
128 visitor.initpacket(P, 0);
129 run<PacketSize>(mat, visitor);
// Later packet step: visit the packet at index K and advance by PacketSize.
133 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
134 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
135 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
136 Packet P = mat.template packet<Packet>(K);
137 visitor.packet(P, K);
138 run<K + PacketSize>(mat, visitor);
// Runtime-sized, scalar traversal with (row, col) indexing.
// An empty matrix returns immediately. Otherwise the visitor is seeded with
// (0, 0), the rest of the first inner vector is visited, and then the remaining
// outer vectors are walked. After every coefficient, short_circuit is asked
// whether to stop early.
// NOTE(review): the declarations of r and c (source lines 156-157, 164-165)
// and the closing braces are missing from this listing.
143template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
144struct visitor_impl<Visitor, Derived,
Dynamic, false, false, ShortCircuitEvaluation> {
145 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
146 static constexpr bool RowMajor = Derived::IsRowMajor;
148 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
149 const Index innerSize =
RowMajor ? mat.cols() : mat.rows();
150 const Index outerSize =
RowMajor ? mat.rows() : mat.cols();
151 if (innerSize == 0 || outerSize == 0)
return;
153 visitor.init(mat.coeff(0, 0), 0, 0);
154 if (short_circuit::run(visitor))
return;
// Remainder of the first inner vector (index 0 was used by init).
155 for (Index i = 1; i < innerSize; ++i) {
158 visitor(mat.coeff(r, c), r, c);
159 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Remaining outer vectors, each visited in full.
162 for (Index j = 1; j < outerSize; j++) {
163 for (Index i = 0; i < innerSize; ++i) {
166 visitor(mat.coeff(r, c), r, c);
167 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Runtime-sized, vectorized traversal with (row, col) indexing.
// After the empty check, the visitor is seeded in one of two ways:
//   - with a single coefficient, when an inner vector is shorter than one packet;
//   - otherwise with the packet at (0, 0).
// Each inner vector is then processed as whole packets while
// i + PacketSize - 1 < innerSize, followed by a scalar tail.
// short_circuit is consulted after every step.
// NOTE(review): the declarations/resets of i, r, c and the else-branch of the
// seeding `if` are not visible in this listing (source line numbers skip).
174template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
175struct visitor_impl<Visitor, Derived,
Dynamic, true, false, ShortCircuitEvaluation> {
176 using Scalar =
typename Derived::Scalar;
177 using Packet =
typename packet_traits<Scalar>::type;
178 static constexpr int PacketSize = packet_traits<Scalar>::size;
179 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
180 static constexpr bool RowMajor = Derived::IsRowMajor;
182 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
183 const Index innerSize =
RowMajor ? mat.cols() : mat.rows();
184 const Index outerSize =
RowMajor ? mat.rows() : mat.cols();
185 if (innerSize == 0 || outerSize == 0)
return;
188 if (innerSize < PacketSize) {
189 visitor.init(mat.coeff(0, 0), 0, 0);
192 Packet p = mat.template packet<Packet>(0, 0);
193 visitor.initpacket(p, 0, 0);
196 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// First inner vector: full packets, then the scalar tail.
197 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
200 Packet p = mat.template packet<Packet>(r, c);
201 visitor.packet(p, r, c);
202 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
204 for (; i < innerSize; ++i) {
207 visitor(mat.coeff(r, c), r, c);
208 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Remaining outer vectors: same packet-then-tail pattern.
211 for (Index j = 1; j < outerSize; j++) {
213 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
216 Packet p = mat.template packet<Packet>(r, c);
217 visitor.packet(p, r, c);
218 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
220 for (; i < innerSize; ++i) {
223 visitor(mat.coeff(r, c), r, c);
224 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Runtime-sized, scalar traversal with linear indexing.
// Seeds the visitor with coefficient 0, then visits indices 1..size-1, stopping
// early whenever short_circuit reports done. An empty input is a no-op.
231template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
232struct visitor_impl<Visitor, Derived,
Dynamic, false, true, ShortCircuitEvaluation> {
233 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
235 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
236 const Index size = mat.size();
237 if (size == 0)
return;
238 visitor.init(mat.coeff(0), 0);
239 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
240 for (Index k = 1; k < size; k++) {
241 visitor(mat.coeff(k), k);
242 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Runtime-sized, vectorized traversal with linear indexing.
// Seeds with a scalar when size < PacketSize, otherwise with the packet at
// index k. It then processes full packets while k + PacketSize - 1 < size,
// followed by a scalar tail. Stops early if short_circuit reports done.
// NOTE(review): the declaration/initialisation of k and the else-branch of the
// seeding `if` are not visible in this listing.
248template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
249struct visitor_impl<Visitor, Derived,
Dynamic, true, true, ShortCircuitEvaluation> {
250 using Scalar =
typename Derived::Scalar;
251 using Packet =
typename packet_traits<Scalar>::type;
252 static constexpr int PacketSize = packet_traits<Scalar>::size;
253 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
255 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
256 const Index size = mat.size();
257 if (size == 0)
return;
259 if (size < PacketSize) {
260 visitor.init(mat.coeff(0), 0);
263 Packet p = mat.template packet<Packet>(k);
264 visitor.initpacket(p, k);
267 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
268 for (; k + PacketSize - 1 < size; k += PacketSize) {
269 Packet p = mat.template packet<Packet>(k);
270 visitor.packet(p, k);
271 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
273 for (; k < size; k++) {
274 visitor(mat.coeff(k), k);
275 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Adapter giving visitor_impl a uniform view of an expression's evaluator.
// It exposes compile-time traits (packet/linear access from the evaluator flags,
// storage order, compile-time sizes, alignment, read cost) and coefficient/packet
// accessors in both (row, col) and linear form.
// m_xpr is stored as a const reference, so this object must not outlive the
// expression passed to the constructor.
281template <
typename XprType>
282class visitor_evaluator {
284 typedef evaluator<XprType> Evaluator;
285 typedef typename XprType::Scalar Scalar;
286 using Packet =
typename packet_traits<Scalar>::type;
287 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
289 static constexpr bool PacketAccess =
static_cast<bool>(Evaluator::Flags &
PacketAccessBit);
290 static constexpr bool LinearAccess =
static_cast<bool>(Evaluator::Flags &
LinearAccessBit);
291 static constexpr bool IsRowMajor =
static_cast<bool>(XprType::IsRowMajor);
292 static constexpr int RowsAtCompileTime = XprType::RowsAtCompileTime;
293 static constexpr int ColsAtCompileTime = XprType::ColsAtCompileTime;
294 static constexpr int XprAlignment = Evaluator::Alignment;
295 static constexpr int CoeffReadCost = Evaluator::CoeffReadCost;
297 EIGEN_DEVICE_FUNC
explicit visitor_evaluator(
const XprType& xpr) : m_evaluator(xpr), m_xpr(xpr) {}
// Dimensions are read from the expression itself.
299 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT {
return m_xpr.rows(); }
300 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT {
return m_xpr.cols(); }
301 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT {
return m_xpr.size(); }
303 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col)
const {
304 return m_evaluator.coeff(row, col);
// (row, col) packet loads default to Unaligned.
306 template <
typename Packet,
int Alignment = Unaligned>
307 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(Index row, Index col)
const {
308 return m_evaluator.template packet<Alignment, Packet>(row, col);
311 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
return m_evaluator.coeff(index); }
// Linear packet loads default to the evaluator's alignment.
312 template <
typename Packet,
int Alignment = XprAlignment>
313 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(Index index)
const {
314 return m_evaluator.template packet<Alignment, Packet>(index);
318 Evaluator m_evaluator;
319 const XprType& m_xpr;
// Selects and runs a visitor_impl specialization for expression Derived.
// - Linear indexing and vectorization are used only when both the evaluator
//   and functor_traits<Visitor> allow them.
// - VectorOps/ScalarOps estimate the packet and scalar steps of a fixed-size
//   traversal.
// - The loop is fully unrolled (UnrollCount = SizeAtCompileTime) when the size
//   is fixed and TotalOps * UnrollCost <= EIGEN_UNROLLING_LIMIT. Otherwise
//   UnrollCount is Dynamic and a runtime loop is used.
// NOTE(review): the template parameter is spelled ShortCircuitEvaulation (sic);
// the spelling is consistent within this block.
322template <
typename Derived,
typename Visitor,
bool ShortCircuitEvaulation>
324 using Evaluator = visitor_evaluator<Derived>;
331 static constexpr int InnerSizeAtCompileTime = IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime;
332 static constexpr int OuterSizeAtCompileTime = IsRowMajor ? RowsAtCompileTime : ColsAtCompileTime;
334 static constexpr bool LinearAccess =
335 Evaluator::LinearAccess &&
static_cast<bool>(functor_traits<Visitor>::LinearAccess);
336 static constexpr bool Vectorize = Evaluator::PacketAccess &&
static_cast<bool>(functor_traits<Visitor>::PacketAccess);
338 static constexpr int PacketSize = packet_traits<Scalar>::size;
339 static constexpr int VectorOps =
340 Vectorize ? (LinearAccess ? (SizeAtCompileTime / PacketSize)
341 : (OuterSizeAtCompileTime * (InnerSizeAtCompileTime / PacketSize)))
343 static constexpr int ScalarOps = SizeAtCompileTime - (VectorOps * PacketSize);
345 static constexpr int TotalOps = VectorOps + ScalarOps;
347 static constexpr int UnrollCost = int(Evaluator::CoeffReadCost) + int(functor_traits<Visitor>::Cost);
348 static constexpr bool Unroll = (SizeAtCompileTime !=
Dynamic) && ((TotalOps * UnrollCost) <= EIGEN_UNROLLING_LIMIT);
349 static constexpr int UnrollCount = Unroll ? int(SizeAtCompileTime) :
Dynamic;
351 using impl = visitor_impl<Visitor, Evaluator, UnrollCount, Vectorize, LinearAccess, ShortCircuitEvaulation>;
// Wrap the expression in a visitor_evaluator and dispatch.
353 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const DenseBase<Derived>& mat, Visitor& visitor) {
354 Evaluator evaluator(mat.derived());
355 impl::run(evaluator, visitor);
// Presumably the body of DenseBase<Derived>::visit(Visitor&); the signature line
// is not in this listing. It runs a user visitor over every coefficient with
// short-circuit evaluation disabled (visit_impl<..., false>).
380template <
typename Derived>
381template <
typename Visitor>
383 using impl = internal::visit_impl<Derived, Visitor,
false>;
384 impl::run(derived(), visitor);
// Base for coefficient-locating visitors. It holds the current best value (res)
// and its position (row, col). Before init, the position is (-1, -1) and res is 0.
392template <
typename Derived>
393struct coeff_visitor {
395 EIGEN_DEVICE_FUNC coeff_visitor() : row(-1), col(-1), res(0) {}
396 typedef typename Derived::Scalar Scalar;
399 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index i, Index j) {
// Comparison policy for the minimum search (is_min == true).
// compare(a, b) is true when a should replace b. predux reduces a packet to its
// minimum under the given NaN propagation mode.
406template <
typename Scalar,
int NaNPropagation,
bool is_min = true>
407struct minmax_compare {
408 typedef typename packet_traits<Scalar>::type Packet;
409 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a < b; }
410 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_min<NaNPropagation>(p); }
// Comparison policy for the maximum search (is_min == false): compare(a, b) is
// true when a > b, and predux reduces a packet to its maximum.
413template <
typename Scalar,
int NaNPropagation>
414struct minmax_compare<Scalar, NaNPropagation, false> {
415 typedef typename packet_traits<Scalar>::type Packet;
416 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a > b; }
417 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_max<NaNPropagation>(p); }
// Min/max locator used when isInt is true. Both floating-point partial
// specializations fix isInt = false, so this primary template serves integer
// scalars, which need no NaN handling.
// Packet path: reduce the packet, and if the reduced value wins, find its lane.
// Assuming plset(1) yields {1, ..., PacketSize}, `range` is {PacketSize, ..., 1}.
// Masking it with the equality mask and taking the maximum gives
// PacketSize - lane of the first match, so max_idx is that lane's offset
// along the inner dimension.
422template <
typename Derived,
bool is_min,
int NaNPropagation,
423 bool isInt = NumTraits<typename Derived::Scalar>::IsInteger>
424struct minmax_coeff_visitor : coeff_visitor<Derived> {
425 using Scalar =
typename Derived::Scalar;
426 using Packet =
typename packet_traits<Scalar>::type;
427 using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
428 static constexpr Index PacketSize = packet_traits<Scalar>::size;
430 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index i, Index j) {
431 if (Comparator::compare(value, this->res)) {
437 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index i,
Index j) {
438 Scalar value = Comparator::predux(p);
439 if (Comparator::compare(value, this->res)) {
440 const Packet range = preverse(plset<Packet>(
Scalar(1)));
441 Packet mask = pcmp_eq(pset1<Packet>(value), p);
442 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
444 this->row = Derived::IsRowMajor ? i : i + max_idx;
445 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Seeding from a packet: same lane search, applied unconditionally.
448 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index i, Index j) {
449 Scalar value = Comparator::predux(p);
450 const Packet range = preverse(plset<Packet>(Scalar(1)));
451 Packet mask = pcmp_eq(pset1<Packet>(value), p);
452 Index max_idx = PacketSize -
static_cast<Index
>(predux_max(pand(range, mask)));
454 this->row = Derived::IsRowMajor ? i : i + max_idx;
455 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Floating-point min/max locator for PropagateNumbers.
// A non-NaN candidate always replaces a NaN incumbent. A NaN candidate never
// wins, because compare() with NaN is false. So NaN survives only if no number
// is ever seen.
// NOTE(review): the operator() signature and the NaN branch body of
// initpacket are not visible in this listing.
461template <
typename Derived,
bool is_min>
462struct minmax_coeff_visitor<Derived, is_min,
PropagateNumbers, false> : coeff_visitor<Derived> {
463 typedef typename Derived::Scalar Scalar;
464 using Packet =
typename packet_traits<Scalar>::type;
465 using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
468 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
474 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index i,
Index j) {
475 const Index PacketSize = packet_traits<Scalar>::size;
476 Scalar value = Comparator::predux(p);
477 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// Lane of the first element equal to the reduced value (see range/mask trick).
478 const Packet range = preverse(plset<Packet>(Scalar(1)));
480 Packet mask = pcmp_eq(pset1<Packet>(value), p);
481 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
483 this->row = Derived::IsRowMajor ? i : i + max_idx;
484 this->col = Derived::IsRowMajor ? j + max_idx : j;
487 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index i,
Index j) {
488 const Index PacketSize = packet_traits<Scalar>::size;
489 Scalar value = Comparator::predux(p);
// A NaN reduction (only possible here if the packet has no numbers) is handled separately.
490 if ((numext::isnan)(value)) {
496 const Packet range = preverse(plset<Packet>(Scalar(1)));
498 Packet mask = pcmp_eq(pset1<Packet>(value), p);
499 Index max_idx = PacketSize -
static_cast<Index
>(predux_max(pand(range, mask)));
501 this->row = Derived::IsRowMajor ? i : i + max_idx;
502 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Floating-point min/max locator for every NaNPropagation mode other than
// PropagateNumbers, which has its own more specialized partial specialization.
// Comparisons use PropagateNaN semantics:
//   - a NaN candidate replaces a non-NaN incumbent;
//   - once res is NaN, nothing replaces it.
// So the first NaN encountered is retained.
508template <
typename Derived,
bool is_min,
int NaNPropagation>
509struct minmax_coeff_visitor<Derived, is_min, NaNPropagation, false> : coeff_visitor<Derived> {
510 typedef typename Derived::Scalar Scalar;
511 using Packet =
typename packet_traits<Scalar>::type;
512 using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
514 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index i, Index j) {
515 const bool value_is_nan = (numext::isnan)(value);
516 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
522 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index i, Index j) {
523 const Index PacketSize = packet_traits<Scalar>::size;
524 Scalar value = Comparator::predux(p);
525 const bool value_is_nan = (numext::isnan)(value);
526 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
527 const Packet range = preverse(plset<Packet>(Scalar(1)));
// NaN lanes are the ones where p != p, so they are found with pnot(pcmp_eq(p, p)).
529 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
530 Index max_idx = PacketSize -
static_cast<Index
>(predux_max(pand(range, mask)));
532 this->row = Derived::IsRowMajor ? i : i + max_idx;
533 this->col = Derived::IsRowMajor ? j + max_idx : j;
536 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index i, Index j) {
537 const Index PacketSize = packet_traits<Scalar>::size;
538 Scalar value = Comparator::predux(p);
539 const bool value_is_nan = (numext::isnan)(value);
540 const Packet range = preverse(plset<Packet>(Scalar(1)));
542 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
543 Index max_idx = PacketSize -
static_cast<Index
>(predux_max(pand(range, mask)));
545 this->row = Derived::IsRowMajor ? i : i + max_idx;
546 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Traits for min/max visitors. They use (row, col) indexing only
// (LinearAccess = false) and vectorize when the packet type supports comparisons.
550template <
typename Derived,
bool is_min,
int NaNPropagation>
551struct functor_traits<minmax_coeff_visitor<Derived, is_min, NaNPropagation>> {
552 using Scalar =
typename Derived::Scalar;
553 enum { Cost = NumTraits<Scalar>::AddCost, LinearAccess =
false, PacketAccess = packet_traits<Scalar>::HasCmp };
// all_visitor (struct header line not in this listing): res is true iff every
// visited coefficient is nonzero.
// done() returns true as soon as res becomes false, which lets a short-circuit
// traversal stop at the first zero.
556template <
typename Scalar>
558 using result_type = bool;
559 using Packet =
typename packet_traits<Scalar>::type;
560 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index, Index) { res = (value != Scalar(0)); }
561 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index) { res = (value != Scalar(0)); }
// True when no lane of p compares equal to zero.
562 EIGEN_DEVICE_FUNC
inline bool all_predux(
const Packet& p)
const {
return !predux_any(pcmp_eq(p, pzero(p))); }
563 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index, Index) { res = all_predux(p); }
564 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index) { res = all_predux(p); }
565 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index, Index) { res = res && (value != Scalar(0)); }
566 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index) { res = res && (value != Scalar(0)); }
567 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index, Index) { res = res && all_predux(p); }
568 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index) { res = res && all_predux(p); }
569 EIGEN_DEVICE_FUNC
inline bool done()
const {
return !res; }
// Traits for all_visitor: supports linear indexing and vectorizes when packet
// comparisons are available.
572template <
typename Scalar>
573struct functor_traits<all_visitor<Scalar>> {
574 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess =
true, PacketAccess = packet_traits<Scalar>::HasCmp };
// any_visitor (struct header line not in this listing): res is true iff some
// visited coefficient is nonzero.
// done() returns true once res is true, allowing early exit.
577template <
typename Scalar>
579 using result_type = bool;
580 using Packet =
typename packet_traits<Scalar>::type;
581 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index, Index) { res = (value != Scalar(0)); }
582 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index) { res = (value != Scalar(0)); }
// True when at least one lane is not equal to zero: the all-true mask minus the equals-zero mask.
583 EIGEN_DEVICE_FUNC
inline bool any_predux(
const Packet& p)
const {
584 return predux_any(pandnot(ptrue(p), pcmp_eq(p, pzero(p))));
586 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index, Index) { res = any_predux(p); }
587 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index) { res = any_predux(p); }
588 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index, Index) { res = res || (value != Scalar(0)); }
589 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index) { res = res || (value != Scalar(0)); }
590 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index, Index) { res = res || any_predux(p); }
591 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index) { res = res || any_predux(p); }
592 EIGEN_DEVICE_FUNC
inline bool done()
const {
return res; }
// Traits for any_visitor: supports linear indexing and vectorizes when packet
// comparisons are available.
595template <
typename Scalar>
596struct functor_traits<any_visitor<Scalar>> {
597 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess =
true, PacketAccess = packet_traits<Scalar>::HasCmp };
// Counts the nonzero coefficients; res holds the running count as an Index.
// There is no done(), so it cannot short-circuit.
600template <
typename Scalar>
601struct count_visitor {
602 using result_type =
Index;
603 using Packet =
typename packet_traits<Scalar>::type;
604 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index, Index) { res = value != Scalar(0) ? 1 : 0; }
605 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index) { res = value != Scalar(0) ? 1 : 0; }
// Number of nonzero lanes. Each lane becomes 1 (nonzero) or 0 by clearing the
// constant 1 wherever p == 0, and the lanes are then summed.
// NOTE(review): the sum is formed in Scalar before conversion to Index, so it
// relies on PacketSize being representable in Scalar.
606 EIGEN_DEVICE_FUNC
inline Index count_redux(
const Packet& p)
const {
607 const Packet cst_one = pset1<Packet>(Scalar(1));
608 Packet true_vals = pandnot(cst_one, pcmp_eq(p, pzero(p)));
609 Scalar num_true = predux(true_vals);
610 return static_cast<Index
>(num_true);
612 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index, Index) { res = count_redux(p); }
613 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p, Index) { res = count_redux(p); }
614 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index, Index) {
615 if (value != Scalar(0)) res++;
617 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value, Index) {
618 if (value != Scalar(0)) res++;
620 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index, Index) { res += count_redux(p); }
621 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p, Index) { res += count_redux(p); }
// Traits for count_visitor. The packet path requires packet compare and add,
// and is disabled for bool scalars.
625template <
typename Scalar>
626struct functor_traits<count_visitor<Scalar>> {
628 Cost = NumTraits<Scalar>::AddCost,
631 PacketAccess = packet_traits<Scalar>::HasCmp && packet_traits<Scalar>::HasAdd && !is_same<Scalar, bool>::value
// Minimum coefficient and its (row, col) location. The signature's first line
// is not in this listing; presumably minCoeff(IndexType* rowId, IndexType* colId).
// Asserts that the matrix is non-empty. colId may be null, in which case only
// the row is written.
648template <
typename Derived>
649template <
int NaNPropagation,
typename IndexType>
651 IndexType* colId)
const {
652 eigen_assert(this->rows() > 0 && this->cols() > 0 &&
"you are using an empty matrix");
654 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
655 this->visit(minVisitor);
656 *rowId = minVisitor.row;
657 if (colId) *colId = minVisitor.col;
658 return minVisitor.res;
// Minimum coefficient of a vector expression and its linear index. The
// signature line is not in this listing; presumably minCoeff(IndexType* index).
// It asserts non-emptiness at runtime and vector-ness at compile time. The
// reported index is the column for a row vector (RowsAtCompileTime == 1) and
// the row otherwise.
672template <
typename Derived>
673template <
int NaNPropagation,
typename IndexType>
675 eigen_assert(this->rows() > 0 && this->cols() > 0 &&
"you are using an empty matrix");
676 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
678 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
679 this->visit(minVisitor);
680 *index = IndexType((RowsAtCompileTime == 1) ? minVisitor.col : minVisitor.row);
681 return minVisitor.res;
// Maximum coefficient and its (row, col) location. The signature's first line
// is not in this listing; presumably maxCoeff(IndexType* rowPtr, IndexType* colPtr).
// Asserts that the matrix is non-empty. colPtr may be null.
695template <
typename Derived>
696template <
int NaNPropagation,
typename IndexType>
698 IndexType* colPtr)
const {
699 eigen_assert(this->rows() > 0 && this->cols() > 0 &&
"you are using an empty matrix");
701 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
702 this->visit(maxVisitor);
703 *rowPtr = maxVisitor.row;
704 if (colPtr) *colPtr = maxVisitor.col;
705 return maxVisitor.res;
// Maximum coefficient of a vector expression and its linear index. The
// signature line is not in this listing; presumably maxCoeff(IndexType* index).
// It asserts non-emptiness at runtime and vector-ness at compile time. The
// reported index is the column for a row vector (RowsAtCompileTime == 1) and
// the row otherwise.
719template <
typename Derived>
720template <
int NaNPropagation,
typename IndexType>
722 eigen_assert(this->rows() > 0 && this->cols() > 0 &&
"you are using an empty matrix");
724 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
725 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
726 this->visit(maxVisitor);
// Convert Index -> IndexType explicitly, as the min counterpart does. This
// avoids a silent (possibly narrowing) implicit conversion when IndexType is
// narrower than Index.
727 *index = IndexType((RowsAtCompileTime == 1) ? maxVisitor.col : maxVisitor.row);
728 return maxVisitor.res;
// Body of all(): true iff every coefficient is nonzero. It uses all_visitor with
// short-circuit evaluation enabled (visit_impl<..., true>), so the traversal can
// stop at the first zero.
// NOTE(review): the visitor declaration and return statement are not in this listing.
738template <
typename Derived>
740 using Visitor = internal::all_visitor<Scalar>;
741 using impl = internal::visit_impl<Derived, Visitor,
true>;
743 impl::run(derived(), visitor);
// Body of any(): true iff some coefficient is nonzero. It uses any_visitor with
// short-circuit evaluation enabled, so the traversal can stop at the first nonzero.
// NOTE(review): the visitor declaration and return statement are not in this listing.
751template <
typename Derived>
753 using Visitor = internal::any_visitor<Scalar>;
754 using impl = internal::visit_impl<Derived, Visitor,
true>;
756 impl::run(derived(), visitor);
// Presumably the body of count(): the number of nonzero coefficients, via
// count_visitor. Short-circuit evaluation is disabled (visit_impl<..., false>)
// because every coefficient must be seen.
764template <
typename Derived>
766 using Visitor = internal::count_visitor<Scalar>;
767 using impl = internal::visit_impl<Derived, Visitor,
false>;
769 impl::run(derived(), visitor);
// Presumably hasNaN(): true if any coefficient compares unequal to itself.
// Under IEEE-754 semantics that is exactly the NaN values.
773template <
typename Derived>
775 return derived().cwiseTypedNotEqual(derived()).
any();
// Presumably allFinite(): true iff every coefficient is finite, via the
// coefficient-wise isFinite() predicate reduced with all().
782template <
typename Derived>
784 return derived().array().isFinite().
all();
Base class for all dense matrices, vectors, and arrays.
Definition DenseBase.h:44
internal::traits< Derived >::Scalar Scalar
Definition DenseBase.h:62
bool any() const
Definition Visitor.h:752
@ SizeAtCompileTime
Definition DenseBase.h:108
@ IsRowMajor
Definition DenseBase.h:166
@ ColsAtCompileTime
Definition DenseBase.h:102
@ RowsAtCompileTime
Definition DenseBase.h:96
bool all() const
Definition Visitor.h:739
@ PropagateNumbers
Definition Constants.h:344
@ RowMajor
Definition Constants.h:320
const unsigned int PacketAccessBit
Definition Constants.h:97
const unsigned int LinearAccessBit
Definition Constants.h:133
Namespace containing all symbols from the Eigen library.
Definition Core:137
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:83
const int Dynamic
Definition Constants.h:25