ProductEvaluators.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2006-2008 Benoit Jacob <[email protected]>
5// Copyright (C) 2008-2010 Gael Guennebaud <[email protected]>
6// Copyright (C) 2011 Jitse Niesen <[email protected]>
7//
8// This Source Code Form is subject to the terms of the Mozilla
9// Public License v. 2.0. If a copy of the MPL was not distributed
10// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11
12#ifndef EIGEN_PRODUCTEVALUATORS_H
13#define EIGEN_PRODUCTEVALUATORS_H
14
15// IWYU pragma: private
16#include "./InternalHeaderCheck.h"
17
18namespace Eigen {
19
20namespace internal {
21
30template <typename Lhs, typename Rhs, int Options>
31struct evaluator<Product<Lhs, Rhs, Options>> : public product_evaluator<Product<Lhs, Rhs, Options>> {
32 typedef Product<Lhs, Rhs, Options> XprType;
33 typedef product_evaluator<XprType> Base;
34
35 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& xpr) : Base(xpr) {}
36};
37
38// Catch "scalar * ( A * B )" and transform it to "(A*scalar) * B"
39// TODO we should apply that rule only if that's really helpful
40template <typename Lhs, typename Rhs, typename Scalar1, typename Scalar2, typename Plain1>
41struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_product_op<Scalar1, Scalar2>,
42 const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
43 const Product<Lhs, Rhs, DefaultProduct>>> {
44 static const bool value = true;
45};
46template <typename Lhs, typename Rhs, typename Scalar1, typename Scalar2, typename Plain1>
47struct evaluator<CwiseBinaryOp<internal::scalar_product_op<Scalar1, Scalar2>,
48 const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
49 const Product<Lhs, Rhs, DefaultProduct>>>
50 : public evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1, Lhs, product), Rhs, DefaultProduct>> {
51 typedef CwiseBinaryOp<internal::scalar_product_op<Scalar1, Scalar2>,
52 const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
53 const Product<Lhs, Rhs, DefaultProduct>>
54 XprType;
55 typedef evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1, Lhs, product), Rhs, DefaultProduct>> Base;
56
57 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& xpr)
58 : Base(xpr.lhs().functor().m_other * xpr.rhs().lhs() * xpr.rhs().rhs()) {}
59};
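// Editor's note (illustrative sketch, not part of Eigen's sources): with the two
// specializations above, a user expression such as
//
//   Eigen::MatrixXd A(4, 4), B(4, 4), C(4, 4);
//   C.noalias() = 2.0 * (A * B);
//
// is evaluated as if (2.0 * A) * B had been written, so the scalar factor is folded into
// the product kernel instead of producing an extra scaling pass over a temporary.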
60
61template <typename Lhs, typename Rhs, int DiagIndex>
62struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex>>
63 : public evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>> {
64 typedef Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> XprType;
65 typedef evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>> Base;
66
67 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& xpr)
68 : Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
69 Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()), xpr.index())) {}
70};
71
72// Helper class to perform a matrix product with the destination at hand.
73// Depending on the sizes of the factors, there are different evaluation strategies
74// as controlled by internal::product_type.
75template <typename Lhs, typename Rhs, typename LhsShape = typename evaluator_traits<Lhs>::Shape,
76 typename RhsShape = typename evaluator_traits<Rhs>::Shape,
77 int ProductType = internal::product_type<Lhs, Rhs>::value>
78struct generic_product_impl;
79
80template <typename Lhs, typename Rhs>
81struct evaluator_assume_aliasing<Product<Lhs, Rhs, DefaultProduct>> {
82 static const bool value = true;
83};
84
85// This is the default evaluator implementation for products:
86// It creates a temporary and calls generic_product_impl
87template <typename Lhs, typename Rhs, int Options, int ProductTag, typename LhsShape, typename RhsShape>
88struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, LhsShape, RhsShape>
89 : public evaluator<typename Product<Lhs, Rhs, Options>::PlainObject> {
90 typedef Product<Lhs, Rhs, Options> XprType;
91 typedef typename XprType::PlainObject PlainObject;
92 typedef evaluator<PlainObject> Base;
93 enum { Flags = Base::Flags | EvalBeforeNestingBit };
94
95 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit product_evaluator(const XprType& xpr)
96 : m_result(xpr.rows(), xpr.cols()) {
97 internal::construct_at<Base>(this, m_result);
98
99 // FIXME shall we handle nested_eval here?
100 // if so, then we must take care at removing the call to nested_eval in the specializations (e.g., in
101 // permutation_matrix_product, transposition_matrix_product, etc.)
102 // typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
103 // typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
104 // typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
105 // typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
106 //
107 // const LhsNested lhs(xpr.lhs());
108 // const RhsNested rhs(xpr.rhs());
109 //
110 // generic_product_impl<LhsNestedCleaned, RhsNestedCleaned>::evalTo(m_result, lhs, rhs);
111
112 generic_product_impl<Lhs, Rhs, LhsShape, RhsShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
113 }
114
115 protected:
116 PlainObject m_result;
117};
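// Editor's note (illustrative sketch, not part of Eigen's sources): this default evaluator
// is what kicks in when a product is nested inside a larger expression, e.g.
//
//   Eigen::MatrixXd A(8, 8), B(8, 8), D(8, 8), C(8, 8);
//   C = D + (A * B).transpose();
//
// Here A * B is evaluated once into m_result (a PlainObject temporary) and the enclosing
// expression then reads its coefficients from that temporary, as requested by EvalBeforeNestingBit.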
118
119// The following three shortcuts are enabled only if the scalar types match exactly.
120// TODO: we could enable them for different scalar types when the product is not vectorized.
121
122// Dense = Product
123template <typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
124struct Assignment<DstXprType, Product<Lhs, Rhs, Options>, internal::assign_op<Scalar, Scalar>, Dense2Dense,
125 std::enable_if_t<(Options == DefaultProduct || Options == AliasFreeProduct)>> {
126 typedef Product<Lhs, Rhs, Options> SrcXprType;
127 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
128 const internal::assign_op<Scalar, Scalar>&) {
129 Index dstRows = src.rows();
130 Index dstCols = src.cols();
131 if ((dst.rows() != dstRows) || (dst.cols() != dstCols)) dst.resize(dstRows, dstCols);
132 // FIXME shall we handle nested_eval here?
133 generic_product_impl<Lhs, Rhs>::evalTo(dst, src.lhs(), src.rhs());
134 }
135};
136
137// Dense += Product
138template <typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
139struct Assignment<DstXprType, Product<Lhs, Rhs, Options>, internal::add_assign_op<Scalar, Scalar>, Dense2Dense,
140 std::enable_if_t<(Options == DefaultProduct || Options == AliasFreeProduct)>> {
141 typedef Product<Lhs, Rhs, Options> SrcXprType;
142 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
143 const internal::add_assign_op<Scalar, Scalar>&) {
144 eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
145 // FIXME shall we handle nested_eval here?
146 generic_product_impl<Lhs, Rhs>::addTo(dst, src.lhs(), src.rhs());
147 }
148};
149
150// Dense -= Product
151template <typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
152struct Assignment<DstXprType, Product<Lhs, Rhs, Options>, internal::sub_assign_op<Scalar, Scalar>, Dense2Dense,
153 std::enable_if_t<(Options == DefaultProduct || Options == AliasFreeProduct)>> {
154 typedef Product<Lhs, Rhs, Options> SrcXprType;
155 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
156 const internal::sub_assign_op<Scalar, Scalar>&) {
157 eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
158 // FIXME shall we handle nested_eval here?
159 generic_product_impl<Lhs, Rhs>::subTo(dst, src.lhs(), src.rhs());
160 }
161};
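// Editor's note (illustrative sketch, not part of Eigen's sources): the three specializations
// above are what plain product assignments dispatch to when the scalar types match:
//
//   Eigen::MatrixXd A(3, 4), B(4, 2), C(3, 2);
//   C.noalias()  = A * B;  // Dense  = Product  -> generic_product_impl<Lhs,Rhs>::evalTo(C, A, B)
//   C.noalias() += A * B;  // Dense += Product  -> generic_product_impl<Lhs,Rhs>::addTo(C, A, B)
//   C.noalias() -= A * B;  // Dense -= Product  -> generic_product_impl<Lhs,Rhs>::subTo(C, A, B)
//
// Without noalias(), evaluator_assume_aliasing routes the product through a temporary first.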
162
163// Dense ?= scalar * Product
164// TODO we should apply that rule only if that's really helpful
165// for instance, this is not good for inner products
166template <typename DstXprType, typename Lhs, typename Rhs, typename AssignFunc, typename Scalar, typename ScalarBis,
167 typename Plain>
168struct Assignment<DstXprType,
169 CwiseBinaryOp<internal::scalar_product_op<ScalarBis, Scalar>,
170 const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>, Plain>,
171 const Product<Lhs, Rhs, DefaultProduct>>,
172 AssignFunc, Dense2Dense> {
173 typedef CwiseBinaryOp<internal::scalar_product_op<ScalarBis, Scalar>,
174 const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>, Plain>,
175 const Product<Lhs, Rhs, DefaultProduct>>
176 SrcXprType;
177 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
178 const AssignFunc& func) {
179 call_assignment_no_alias(dst, (src.lhs().functor().m_other * src.rhs().lhs()) * src.rhs().rhs(), func);
180 }
181};
182
183//----------------------------------------
184// Catch "Dense ?= xpr + Product<>" expression to save one temporary
185// FIXME we could probably enable these rules for any product, i.e., not only Dense and DefaultProduct
186
187template <typename OtherXpr, typename Lhs, typename Rhs>
188struct evaluator_assume_aliasing<
189 CwiseBinaryOp<
190 internal::scalar_sum_op<typename OtherXpr::Scalar, typename Product<Lhs, Rhs, DefaultProduct>::Scalar>,
191 const OtherXpr, const Product<Lhs, Rhs, DefaultProduct>>,
192 DenseShape> {
193 static const bool value = true;
194};
195
196template <typename OtherXpr, typename Lhs, typename Rhs>
197struct evaluator_assume_aliasing<
198 CwiseBinaryOp<
199 internal::scalar_difference_op<typename OtherXpr::Scalar, typename Product<Lhs, Rhs, DefaultProduct>::Scalar>,
200 const OtherXpr, const Product<Lhs, Rhs, DefaultProduct>>,
201 DenseShape> {
202 static const bool value = true;
203};
204
205template <typename DstXprType, typename OtherXpr, typename ProductType, typename Func1, typename Func2>
206struct assignment_from_xpr_op_product {
207 template <typename SrcXprType, typename InitialFunc>
208 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src,
209 const InitialFunc& /*func*/) {
210 call_assignment_no_alias(dst, src.lhs(), Func1());
211 call_assignment_no_alias(dst, src.rhs(), Func2());
212 }
213};
214
215#define EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(ASSIGN_OP, BINOP, ASSIGN_OP2) \
216 template <typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, \
217 typename SrcScalar, typename OtherScalar, typename ProdScalar> \
218 struct Assignment<DstXprType, \
219 CwiseBinaryOp<internal::BINOP<OtherScalar, ProdScalar>, const OtherXpr, \
220 const Product<Lhs, Rhs, DefaultProduct>>, \
221 internal::ASSIGN_OP<DstScalar, SrcScalar>, Dense2Dense> \
222 : assignment_from_xpr_op_product<DstXprType, OtherXpr, Product<Lhs, Rhs, DefaultProduct>, \
223 internal::ASSIGN_OP<DstScalar, OtherScalar>, \
224 internal::ASSIGN_OP2<DstScalar, ProdScalar>> {}
225
226EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_sum_op, add_assign_op);
227EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op, scalar_sum_op, add_assign_op);
228EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op, scalar_sum_op, sub_assign_op);
229
230EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_difference_op, sub_assign_op);
231EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op, scalar_difference_op, sub_assign_op);
232EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op, scalar_difference_op, add_assign_op);
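// Editor's note (illustrative sketch, not part of Eigen's sources): with the rules generated
// above, an expression such as
//
//   Eigen::MatrixXd A(4, 4), B(4, 4), D(4, 4), C(4, 4);
//   C.noalias() = D - A * B;
//
// is evaluated as "C = D;" followed by "C -= A * B;" (assign_op for the left operand,
// sub_assign_op for the product), so no temporary is allocated for A * B.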
233
234//----------------------------------------
235
236template <typename Lhs, typename Rhs>
237struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, InnerProduct> {
238 template <typename Dst>
239 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
240 dst.coeffRef(0, 0) = (lhs.transpose().cwiseProduct(rhs)).sum();
241 }
242
243 template <typename Dst>
244 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
245 dst.coeffRef(0, 0) += (lhs.transpose().cwiseProduct(rhs)).sum();
246 }
247
248 template <typename Dst>
249 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
250 dst.coeffRef(0, 0) -= (lhs.transpose().cwiseProduct(rhs)).sum();
251 }
252};
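// Editor's note (illustrative sketch, not part of Eigen's sources): the InnerProduct path
// covers products with a 1x1 result, e.g.
//
//   Eigen::VectorXd u(5), v(5);
//   Eigen::Matrix<double, 1, 1> r;
//   r.noalias() = u.transpose() * v;  // r(0,0) = sum_i u(i) * v(i), via cwiseProduct(...).sum()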
253
254/***********************************************************************
255 * Implementation of outer dense * dense vector product
256 ***********************************************************************/
257
258// Column major result
259template <typename Dst, typename Lhs, typename Rhs, typename Func>
260void EIGEN_DEVICE_FUNC outer_product_selector_run(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Func& func,
261 const false_type&) {
262 evaluator<Rhs> rhsEval(rhs);
263 ei_declare_local_nested_eval(Lhs, lhs, Rhs::SizeAtCompileTime, actual_lhs);
264 // FIXME if cols is large enough, then it might be useful to make sure that lhs is sequentially stored
265 // FIXME not very good if rhs is real and lhs complex while alpha is real too
266 const Index cols = dst.cols();
267 for (Index j = 0; j < cols; ++j) func(dst.col(j), rhsEval.coeff(Index(0), j) * actual_lhs);
268}
269
270// Row major result
271template <typename Dst, typename Lhs, typename Rhs, typename Func>
272void EIGEN_DEVICE_FUNC outer_product_selector_run(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Func& func,
273 const true_type&) {
274 evaluator<Lhs> lhsEval(lhs);
275 ei_declare_local_nested_eval(Rhs, rhs, Lhs::SizeAtCompileTime, actual_rhs);
276 // FIXME if rows is large enough, then it might be useful to make sure that rhs is sequentially stored
277 // FIXME not very good if lhs is real and rhs complex while alpha is real too
278 const Index rows = dst.rows();
279 for (Index i = 0; i < rows; ++i) func(dst.row(i), lhsEval.coeff(i, Index(0)) * actual_rhs);
280}
281
282template <typename Lhs, typename Rhs>
283struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, OuterProduct> {
284 template <typename T>
285 struct is_row_major : std::conditional_t<(int(T::Flags) & RowMajorBit), internal::true_type, internal::false_type> {};
286 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
287
288 // TODO it would be nice to be able to exploit our *_assign_op functors for that purpose
289 struct set {
290 template <typename Dst, typename Src>
291 EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const {
292 dst.const_cast_derived() = src;
293 }
294 };
295 struct add {
296 template <typename Dst, typename Src>
297 EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const {
298 dst.const_cast_derived() += src;
299 }
300 };
301 struct sub {
302 template <typename Dst, typename Src>
303 EIGEN_DEVICE_FUNC void operator()(const Dst& dst, const Src& src) const {
304 dst.const_cast_derived() -= src;
305 }
306 };
307 struct adds {
308 Scalar m_scale;
309 explicit adds(const Scalar& s) : m_scale(s) {}
310 template <typename Dst, typename Src>
311 void EIGEN_DEVICE_FUNC operator()(const Dst& dst, const Src& src) const {
312 dst.const_cast_derived() += m_scale * src;
313 }
314 };
315
316 template <typename Dst>
317 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
318 internal::outer_product_selector_run(dst, lhs, rhs, set(), is_row_major<Dst>());
319 }
320
321 template <typename Dst>
322 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
323 internal::outer_product_selector_run(dst, lhs, rhs, add(), is_row_major<Dst>());
324 }
325
326 template <typename Dst>
327 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
328 internal::outer_product_selector_run(dst, lhs, rhs, sub(), is_row_major<Dst>());
329 }
330
331 template <typename Dst>
332 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs,
333 const Scalar& alpha) {
334 internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), is_row_major<Dst>());
335 }
336};
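// Editor's note (illustrative sketch, not part of Eigen's sources): for a column-major
// destination, the OuterProduct path above loops over the destination columns:
//
//   Eigen::VectorXd u(6), w(4);
//   Eigen::MatrixXd C(6, 4);
//   C.noalias()  = u * w.transpose();  // for each j: C.col(j)  = w(j) * u   ("set" functor)
//   C.noalias() += u * w.transpose();  // for each j: C.col(j) += w(j) * u   ("add" functor)
//
// A row-major destination is handled symmetrically, one row at a time.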
337
338// This base class provides default implementations for evalTo, addTo, subTo, in terms of scaleAndAddTo
339template <typename Lhs, typename Rhs, typename Derived>
340struct generic_product_impl_base {
341 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
342
343 template <typename Dst>
344 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
345 dst.setZero();
346 scaleAndAddTo(dst, lhs, rhs, Scalar(1));
347 }
348
349 template <typename Dst>
350 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
351 scaleAndAddTo(dst, lhs, rhs, Scalar(1));
352 }
353
354 template <typename Dst>
355 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
356 scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
357 }
358
359 template <typename Dst>
360 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs,
361 const Scalar& alpha) {
362 Derived::scaleAndAddTo(dst, lhs, rhs, alpha);
363 }
364};
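// Editor's note (illustrative sketch, not part of Eigen's sources): thanks to this base class
// a product implementation only needs to provide scaleAndAddTo(); the other entry points
// reduce to it as follows:
//
//   evalTo(dst, lhs, rhs)  ==  dst.setZero(); Derived::scaleAndAddTo(dst, lhs, rhs, Scalar(1));
//   addTo(dst, lhs, rhs)   ==  Derived::scaleAndAddTo(dst, lhs, rhs, Scalar(1));
//   subTo(dst, lhs, rhs)   ==  Derived::scaleAndAddTo(dst, lhs, rhs, Scalar(-1));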
365
366template <typename Lhs, typename Rhs>
367struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, GemvProduct>
368 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, GemvProduct>> {
369 typedef typename nested_eval<Lhs, 1>::type LhsNested;
370 typedef typename nested_eval<Rhs, 1>::type RhsNested;
371 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
372 enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
373 typedef internal::remove_all_t<std::conditional_t<int(Side) == OnTheRight, LhsNested, RhsNested>> MatrixType;
374
375 template <typename Dest>
376 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs,
377 const Scalar& alpha) {
378 // Fall back to the inner product if both the lhs and rhs are runtime vectors.
379 if (lhs.rows() == 1 && rhs.cols() == 1) {
380 dst.coeffRef(0, 0) += alpha * lhs.row(0).conjugate().dot(rhs.col(0));
381 return;
382 }
383 LhsNested actual_lhs(lhs);
384 RhsNested actual_rhs(rhs);
385 internal::gemv_dense_selector<Side, (int(MatrixType::Flags) & RowMajorBit) ? RowMajor : ColMajor,
386 bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(actual_lhs,
387 actual_rhs, dst,
388 alpha);
389 }
390};
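// Editor's note (illustrative sketch, not part of Eigen's sources): the GemvProduct path
// handles matrix * vector and vector * matrix products, e.g.
//
//   Eigen::MatrixXd A(5, 3);
//   Eigen::VectorXd x(3), y(5);
//   y.noalias() = A * x;  // Side == OnTheRight, dispatched to gemv_dense_selector
//
// The degenerate case where both operands are runtime vectors (lhs.rows() == 1 and
// rhs.cols() == 1) falls back to the dot-product branch above.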
391
392template <typename Lhs, typename Rhs>
393struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, CoeffBasedProductMode> {
394 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
395
396 template <typename Dst>
397 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
398 // Same as: dst.noalias() = lhs.lazyProduct(rhs);
399 // but easier on the compiler side
400 call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::assign_op<typename Dst::Scalar, Scalar>());
401 }
402
403 template <typename Dst>
404 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
405 // dst.noalias() += lhs.lazyProduct(rhs);
406 call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op<typename Dst::Scalar, Scalar>());
407 }
408
409 template <typename Dst>
410 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
411 // dst.noalias() -= lhs.lazyProduct(rhs);
412 call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar, Scalar>());
413 }
414
415 // This is a special evaluation path called from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
416 // This variant tries to extract scalar multiples from both the LHS and RHS and factor them out. For instance:
417 // dst {,+,-}= (s1*A)*(B*s2)
418 // will be rewritten as:
419 // dst {,+,-}= (s1*s2) * (A.lazyProduct(B))
420 // There are at least four benefits of doing so:
421 // 1 - huge performance gain for heap-allocated matrix types as it saves costly allocations.
422 // 2 - it is faster than simply by-passing the heap allocation through stack allocation.
423 // 3 - it makes this fallback consistent with the heavy GEMM routine.
424 // 4 - it fully by-passes huge stack allocation attempts when multiplying huge fixed-size matrices.
425 // (see https://stackoverflow.com/questions/54738495)
426 // For small fixed-size matrices, however, the gains are less obvious: it is sometimes 2x faster, but sometimes 3x
427 // slower, and the behavior also depends a lot on the compiler... This is why this rewriting strategy is currently
428 // enabled only when falling back from the main GEMM.
429 template <typename Dst, typename Func>
430 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void eval_dynamic(Dst& dst, const Lhs& lhs, const Rhs& rhs,
431 const Func& func) {
432 enum {
433 HasScalarFactor = blas_traits<Lhs>::HasScalarFactor || blas_traits<Rhs>::HasScalarFactor,
434 ConjLhs = blas_traits<Lhs>::NeedToConjugate,
435 ConjRhs = blas_traits<Rhs>::NeedToConjugate
436 };
437 // FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto
438 // this is important for real*complex_mat
439 Scalar actualAlpha = combine_scalar_factors<Scalar>(lhs, rhs);
440
441 eval_dynamic_impl(dst, blas_traits<Lhs>::extract(lhs).template conjugateIf<ConjLhs>(),
442 blas_traits<Rhs>::extract(rhs).template conjugateIf<ConjRhs>(), func, actualAlpha,
443 std::conditional_t<HasScalarFactor, true_type, false_type>());
444 }
445
446 protected:
447 template <typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
448 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs,
449 const Func& func, const Scalar& s /* == 1 */,
450 false_type) {
451 EIGEN_UNUSED_VARIABLE(s);
452 eigen_internal_assert(numext::is_exactly_one(s));
453 call_restricted_packet_assignment_no_alias(dst, lhs.lazyProduct(rhs), func);
454 }
455
456 template <typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
457 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs,
458 const Func& func, const Scalar& s, true_type) {
459 call_restricted_packet_assignment_no_alias(dst, s * lhs.lazyProduct(rhs), func);
460 }
461};
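// Editor's note (illustrative sketch, not part of Eigen's sources): when the GEMM kernel
// falls back to eval_dynamic() for small enough runtime sizes, an expression such as
//
//   Eigen::MatrixXd A(4, 4), B(4, 4), C(4, 4);
//   double s1 = 2.0, s2 = 0.5;
//   C.noalias() = (s1 * A) * (B * s2);
//
// is effectively evaluated as (s1 * s2) * (A.lazyProduct(B)): blas_traits extracts the scalar
// factors, combine_scalar_factors merges them into actualAlpha, and the true_type overload of
// eval_dynamic_impl applies the combined factor to the lazy product.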
462
463// This specialization enforces the use of a coefficient-based evaluation strategy
464template <typename Lhs, typename Rhs>
465struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, LazyCoeffBasedProductMode>
466 : generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, CoeffBasedProductMode> {};
467
468// Case 2: Evaluate coeff by coeff
469//
470// This is mostly taken from CoeffBasedProduct.h
471// The main difference is that we add an extra argument to the etor_product_*_impl::run() function
472// for the inner dimension of the product, because evaluator objects do not know their size.
473
474template <int Traversal, int UnrollingIndex, typename Lhs, typename Rhs, typename RetScalar>
475struct etor_product_coeff_impl;
476
477template <int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
478struct etor_product_packet_impl;
479
480template <typename Lhs, typename Rhs, int ProductTag>
481struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape, DenseShape>
482 : evaluator_base<Product<Lhs, Rhs, LazyProduct>> {
483 typedef Product<Lhs, Rhs, LazyProduct> XprType;
484 typedef typename XprType::Scalar Scalar;
485 typedef typename XprType::CoeffReturnType CoeffReturnType;
486
487 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit product_evaluator(const XprType& xpr)
488 : m_lhs(xpr.lhs()),
489 m_rhs(xpr.rhs()),
490 m_lhsImpl(m_lhs), // FIXME the creation of the evaluator objects should result in a no-op, but check that!
491 m_rhsImpl(m_rhs), // Moreover, they are only useful for the packet path, so we could completely disable
492 // them when not needed, or perhaps declare them on the fly in the packet method... We
493 // have to experiment to check what's best.
494 m_innerDim(xpr.lhs().cols()) {
495 EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits<Scalar>::MulCost);
496 EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits<Scalar>::AddCost);
497 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
498#if 0
499 std::cerr << "LhsOuterStrideBytes= " << LhsOuterStrideBytes << "\n";
500 std::cerr << "RhsOuterStrideBytes= " << RhsOuterStrideBytes << "\n";
501 std::cerr << "LhsAlignment= " << LhsAlignment << "\n";
502 std::cerr << "RhsAlignment= " << RhsAlignment << "\n";
503 std::cerr << "CanVectorizeLhs= " << CanVectorizeLhs << "\n";
504 std::cerr << "CanVectorizeRhs= " << CanVectorizeRhs << "\n";
505 std::cerr << "CanVectorizeInner= " << CanVectorizeInner << "\n";
506 std::cerr << "EvalToRowMajor= " << EvalToRowMajor << "\n";
507 std::cerr << "Alignment= " << Alignment << "\n";
508 std::cerr << "Flags= " << Flags << "\n";
509#endif
510 }
511
512 // Everything below here is taken from CoeffBasedProduct.h
513
514 typedef typename internal::nested_eval<Lhs, Rhs::ColsAtCompileTime>::type LhsNested;
515 typedef typename internal::nested_eval<Rhs, Lhs::RowsAtCompileTime>::type RhsNested;
516
517 typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
518 typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
519
520 typedef evaluator<LhsNestedCleaned> LhsEtorType;
521 typedef evaluator<RhsNestedCleaned> RhsEtorType;
522
523 enum {
524 RowsAtCompileTime = LhsNestedCleaned::RowsAtCompileTime,
525 ColsAtCompileTime = RhsNestedCleaned::ColsAtCompileTime,
526 InnerSize = min_size_prefer_fixed(LhsNestedCleaned::ColsAtCompileTime, RhsNestedCleaned::RowsAtCompileTime),
527 MaxRowsAtCompileTime = LhsNestedCleaned::MaxRowsAtCompileTime,
528 MaxColsAtCompileTime = RhsNestedCleaned::MaxColsAtCompileTime
529 };
530
531 typedef typename find_best_packet<Scalar, RowsAtCompileTime>::type LhsVecPacketType;
532 typedef typename find_best_packet<Scalar, ColsAtCompileTime>::type RhsVecPacketType;
533
534 enum {
535
536 LhsCoeffReadCost = LhsEtorType::CoeffReadCost,
537 RhsCoeffReadCost = RhsEtorType::CoeffReadCost,
538 CoeffReadCost = InnerSize == 0 ? NumTraits<Scalar>::ReadCost
539 : InnerSize == Dynamic
540 ? HugeCost
541 : InnerSize * (NumTraits<Scalar>::MulCost + int(LhsCoeffReadCost) + int(RhsCoeffReadCost)) +
542 (InnerSize - 1) * NumTraits<Scalar>::AddCost,
543
544 Unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT,
545
546 LhsFlags = LhsEtorType::Flags,
547 RhsFlags = RhsEtorType::Flags,
548
549 LhsRowMajor = LhsFlags & RowMajorBit,
550 RhsRowMajor = RhsFlags & RowMajorBit,
551
552 LhsVecPacketSize = unpacket_traits<LhsVecPacketType>::size,
553 RhsVecPacketSize = unpacket_traits<RhsVecPacketType>::size,
554
555 // Here, we don't care about alignment larger than the usable packet size.
556 LhsAlignment =
557 plain_enum_min(LhsEtorType::Alignment, LhsVecPacketSize* int(sizeof(typename LhsNestedCleaned::Scalar))),
558 RhsAlignment =
559 plain_enum_min(RhsEtorType::Alignment, RhsVecPacketSize* int(sizeof(typename RhsNestedCleaned::Scalar))),
560
561 SameType = is_same<typename LhsNestedCleaned::Scalar, typename RhsNestedCleaned::Scalar>::value,
562
563 CanVectorizeRhs = bool(RhsRowMajor) && (RhsFlags & PacketAccessBit) && (ColsAtCompileTime != 1),
564 CanVectorizeLhs = (!LhsRowMajor) && (LhsFlags & PacketAccessBit) && (RowsAtCompileTime != 1),
565
566 EvalToRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1
567 : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1)
568 ? 0
569 : (bool(RhsRowMajor) && !CanVectorizeLhs),
570
571 Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & ~RowMajorBit) |
572 (EvalToRowMajor ? RowMajorBit : 0)
573 // TODO enable vectorization for mixed types
574 | (SameType && (CanVectorizeLhs || CanVectorizeRhs) ? PacketAccessBit : 0) |
575 (XprType::IsVectorAtCompileTime ? LinearAccessBit : 0),
576
577 LhsOuterStrideBytes =
578 int(LhsNestedCleaned::OuterStrideAtCompileTime) * int(sizeof(typename LhsNestedCleaned::Scalar)),
579 RhsOuterStrideBytes =
580 int(RhsNestedCleaned::OuterStrideAtCompileTime) * int(sizeof(typename RhsNestedCleaned::Scalar)),
581
582 Alignment = bool(CanVectorizeLhs)
583 ? (LhsOuterStrideBytes <= 0 || (int(LhsOuterStrideBytes) % plain_enum_max(1, LhsAlignment)) != 0
584 ? 0
585 : LhsAlignment)
586 : bool(CanVectorizeRhs)
587 ? (RhsOuterStrideBytes <= 0 || (int(RhsOuterStrideBytes) % plain_enum_max(1, RhsAlignment)) != 0
588 ? 0
589 : RhsAlignment)
590 : 0,
591
592 /* CanVectorizeInner deserves special explanation. It does not affect the product flags. It is not used outside
593 * of Product. If the Product itself is not a packet-access expression, there is still a chance that the inner
594 * loop of the product might be vectorized. This is the meaning of CanVectorizeInner. Since it doesn't affect
595 * the Flags, it is safe to make this value depend on ActualPacketAccessBit, that doesn't affect the ABI.
596 */
597 CanVectorizeInner = SameType && LhsRowMajor && (!RhsRowMajor) &&
598 (int(LhsFlags) & int(RhsFlags) & ActualPacketAccessBit) &&
599 (int(InnerSize) % packet_traits<Scalar>::size == 0)
600 };
601
602 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index row, Index col) const {
603 return (m_lhs.row(row).transpose().cwiseProduct(m_rhs.col(col))).sum();
604 }
605
606 /* Allow index-based non-packet access. It is impossible, though, to allow index-based packet access,
607 * which is why we don't set the LinearAccessBit.
608 * TODO: this seems possible when the result is a vector
609 */
610 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index index) const {
611 const Index row = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime == 1) ? 0 : index;
612 const Index col = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime == 1) ? index : 0;
613 return (m_lhs.row(row).transpose().cwiseProduct(m_rhs.col(col))).sum();
614 }
615
616 template <int LoadMode, typename PacketType>
617 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packet(Index row, Index col) const {
618 PacketType res;
619 typedef etor_product_packet_impl<bool(int(Flags) & RowMajorBit) ? RowMajor : ColMajor,
620 Unroll ? int(InnerSize) : Dynamic, LhsEtorType, RhsEtorType, PacketType, LoadMode>
621 PacketImpl;
622 PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res);
623 return res;
624 }
625
626 template <int LoadMode, typename PacketType>
627 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packet(Index index) const {
628 const Index row = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime == 1) ? 0 : index;
629 const Index col = (RowsAtCompileTime == 1 || MaxRowsAtCompileTime == 1) ? index : 0;
630 return packet<LoadMode, PacketType>(row, col);
631 }
632
633 protected:
634 add_const_on_value_type_t<LhsNested> m_lhs;
635 add_const_on_value_type_t<RhsNested> m_rhs;
636
637 LhsEtorType m_lhsImpl;
638 RhsEtorType m_rhsImpl;
639
640 // TODO: Get rid of m_innerDim if known at compile time
641 Index m_innerDim;
642};
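// Editor's note (illustrative sketch, not part of Eigen's sources): this is the evaluator
// behind lazyProduct(); coefficients are recomputed on every access, e.g.
//
//   Eigen::MatrixXd A(3, 5), B(5, 2), C(3, 2);
//   C = A.lazyProduct(B);                  // coefficient-based evaluation, no product temporary
//   double c01 = A.lazyProduct(B)(0, 1);   // one coefficient: dot product of A.row(0) and B.col(1)
//
// which is why repeatedly reading coefficients of a large lazy product is more expensive than
// evaluating it once into a matrix.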
643
644template <typename Lhs, typename Rhs>
645struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, LazyCoeffBasedProductMode, DenseShape, DenseShape>
646 : product_evaluator<Product<Lhs, Rhs, LazyProduct>, CoeffBasedProductMode, DenseShape, DenseShape> {
647 typedef Product<Lhs, Rhs, DefaultProduct> XprType;
648 typedef Product<Lhs, Rhs, LazyProduct> BaseProduct;
649 typedef product_evaluator<BaseProduct, CoeffBasedProductMode, DenseShape, DenseShape> Base;
650 enum { Flags = Base::Flags | EvalBeforeNestingBit };
651 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit product_evaluator(const XprType& xpr)
652 : Base(BaseProduct(xpr.lhs(), xpr.rhs())) {}
653};
654
655/****************************************
656*** Coeff based product, Packet path ***
657****************************************/
658
659template <int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
660struct etor_product_packet_impl<RowMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode> {
661 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
662 Index innerDim, Packet& res) {
663 etor_product_packet_impl<RowMajor, UnrollingIndex - 1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs,
664 innerDim, res);
665 res = pmadd(pset1<Packet>(lhs.coeff(row, Index(UnrollingIndex - 1))),
666 rhs.template packet<LoadMode, Packet>(Index(UnrollingIndex - 1), col), res);
667 }
668};
669
670template <int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
671struct etor_product_packet_impl<ColMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode> {
672 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
673 Index innerDim, Packet& res) {
674 etor_product_packet_impl<ColMajor, UnrollingIndex - 1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs,
675 innerDim, res);
676 res = pmadd(lhs.template packet<LoadMode, Packet>(row, Index(UnrollingIndex - 1)),
677 pset1<Packet>(rhs.coeff(Index(UnrollingIndex - 1), col)), res);
678 }
679};
680
681template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
682struct etor_product_packet_impl<RowMajor, 1, Lhs, Rhs, Packet, LoadMode> {
683 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
684 Index /*innerDim*/, Packet& res) {
685 res = pmul(pset1<Packet>(lhs.coeff(row, Index(0))), rhs.template packet<LoadMode, Packet>(Index(0), col));
686 }
687};
688
689template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
690struct etor_product_packet_impl<ColMajor, 1, Lhs, Rhs, Packet, LoadMode> {
691 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
692 Index /*innerDim*/, Packet& res) {
693 res = pmul(lhs.template packet<LoadMode, Packet>(row, Index(0)), pset1<Packet>(rhs.coeff(Index(0), col)));
694 }
695};
696
697template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
698struct etor_product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode> {
699 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/,
700 const Rhs& /*rhs*/, Index /*innerDim*/, Packet& res) {
701 res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
702 }
703};
704
705template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
706struct etor_product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode> {
707 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/,
708 const Rhs& /*rhs*/, Index /*innerDim*/, Packet& res) {
709 res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
710 }
711};
712
713template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
714struct etor_product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode> {
715 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
716 Index innerDim, Packet& res) {
717 res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
718 for (Index i = 0; i < innerDim; ++i)
719 res = pmadd(pset1<Packet>(lhs.coeff(row, i)), rhs.template packet<LoadMode, Packet>(i, col), res);
720 }
721};
722
723template <typename Lhs, typename Rhs, typename Packet, int LoadMode>
724struct etor_product_packet_impl<ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode> {
725 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs,
726 Index innerDim, Packet& res) {
727 res = pset1<Packet>(typename unpacket_traits<Packet>::type(0));
728 for (Index i = 0; i < innerDim; ++i)
729 res = pmadd(lhs.template packet<LoadMode, Packet>(row, i), pset1<Packet>(rhs.coeff(i, col)), res);
730 }
731};
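// Editor's note (illustrative sketch, not part of Eigen's sources): for a given (row, col),
// the packet kernels above accumulate over the inner dimension k, e.g. for a column-major
// result:
//
//   res = sum_k  lhs.packet(row, k) * pset1(rhs.coeff(k, col))
//
// either fully unrolled (UnrollingIndex == InnerSize) or with the runtime loop of the
// Dynamic specializations.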
732
733/***************************************************************************
734 * Triangular products
735 ***************************************************************************/
736template <int Mode, bool LhsIsTriangular, typename Lhs, bool LhsIsVector, typename Rhs, bool RhsIsVector>
737struct triangular_product_impl;
738
739template <typename Lhs, typename Rhs, int ProductTag>
740struct generic_product_impl<Lhs, Rhs, TriangularShape, DenseShape, ProductTag>
741 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, TriangularShape, DenseShape, ProductTag>> {
742 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
743
744 template <typename Dest>
745 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
746 triangular_product_impl<Lhs::Mode, true, typename Lhs::MatrixType, false, Rhs, Rhs::ColsAtCompileTime == 1>::run(
747 dst, lhs.nestedExpression(), rhs, alpha);
748 }
749};
750
751template <typename Lhs, typename Rhs, int ProductTag>
752struct generic_product_impl<Lhs, Rhs, DenseShape, TriangularShape, ProductTag>
753 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, TriangularShape, ProductTag>> {
754 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
755
756 template <typename Dest>
757 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
758 triangular_product_impl<Rhs::Mode, false, Lhs, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, false>::run(
759 dst, lhs, rhs.nestedExpression(), alpha);
760 }
761};
762
763/***************************************************************************
764 * SelfAdjoint products
765 ***************************************************************************/
766template <typename Lhs, int LhsMode, bool LhsIsVector, typename Rhs, int RhsMode, bool RhsIsVector>
767struct selfadjoint_product_impl;
768
769template <typename Lhs, typename Rhs, int ProductTag>
770struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>
771 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>> {
772 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
773
774 template <typename Dest>
775 static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
776 selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
777 dst, lhs.nestedExpression(), rhs, alpha);
778 }
779};
780
781template <typename Lhs, typename Rhs, int ProductTag>
782struct generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>
783 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>> {
784 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
785
786 template <typename Dest>
787 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
788 selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
789 dst, lhs, rhs.nestedExpression(), alpha);
790 }
791};
792
793/***************************************************************************
794 * Diagonal products
795 ***************************************************************************/
796
797template <typename MatrixType, typename DiagonalType, typename Derived, int ProductOrder>
798struct diagonal_product_evaluator_base : evaluator_base<Derived> {
799 typedef typename ScalarBinaryOpTraits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
800
801 public:
802 enum {
803 CoeffReadCost = int(NumTraits<Scalar>::MulCost) + int(evaluator<MatrixType>::CoeffReadCost) +
804 int(evaluator<DiagonalType>::CoeffReadCost),
805
806 MatrixFlags = evaluator<MatrixType>::Flags,
807 DiagFlags = evaluator<DiagonalType>::Flags,
808
809 StorageOrder_ = (Derived::MaxRowsAtCompileTime == 1 && Derived::MaxColsAtCompileTime != 1) ? RowMajor
810 : (Derived::MaxColsAtCompileTime == 1 && Derived::MaxRowsAtCompileTime != 1) ? ColMajor
811 : MatrixFlags & RowMajorBit ? RowMajor
812 : ColMajor,
813 SameStorageOrder_ = StorageOrder_ == (MatrixFlags & RowMajorBit ? RowMajor : ColMajor),
814
815 ScalarAccessOnDiag_ = !((int(StorageOrder_) == ColMajor && int(ProductOrder) == OnTheLeft) ||
816 (int(StorageOrder_) == RowMajor && int(ProductOrder) == OnTheRight)),
817 SameTypes_ = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
818 // FIXME currently we need the same types, but in the future the next rule should be the one
819 // Vectorizable_ = bool(int(MatrixFlags)&PacketAccessBit) && ((!_PacketOnDiag) || (SameTypes_ &&
820 // bool(int(DiagFlags)&PacketAccessBit))),
821 Vectorizable_ = bool(int(MatrixFlags) & PacketAccessBit) && SameTypes_ &&
822 (SameStorageOrder_ || (MatrixFlags & LinearAccessBit) == LinearAccessBit) &&
823 (ScalarAccessOnDiag_ || (bool(int(DiagFlags) & PacketAccessBit))),
824 LinearAccessMask_ =
825 (MatrixType::RowsAtCompileTime == 1 || MatrixType::ColsAtCompileTime == 1) ? LinearAccessBit : 0,
826 Flags =
827 ((HereditaryBits | LinearAccessMask_) & (unsigned int)(MatrixFlags)) | (Vectorizable_ ? PacketAccessBit : 0),
828 Alignment = evaluator<MatrixType>::Alignment,
829
830 AsScalarProduct =
831 (DiagonalType::SizeAtCompileTime == 1) ||
832 (DiagonalType::SizeAtCompileTime == Dynamic && MatrixType::RowsAtCompileTime == 1 &&
833 ProductOrder == OnTheLeft) ||
834 (DiagonalType::SizeAtCompileTime == Dynamic && MatrixType::ColsAtCompileTime == 1 && ProductOrder == OnTheRight)
835 };
836
837 EIGEN_DEVICE_FUNC diagonal_product_evaluator_base(const MatrixType& mat, const DiagonalType& diag)
838 : m_diagImpl(diag), m_matImpl(mat) {
839 EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits<Scalar>::MulCost);
840 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
841 }
842
843 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const {
844 if (AsScalarProduct)
845 return m_diagImpl.coeff(0) * m_matImpl.coeff(idx);
846 else
847 return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
848 }
849
850 protected:
851 template <int LoadMode, typename PacketType>
852 EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index id, internal::true_type) const {
853 return internal::pmul(m_matImpl.template packet<LoadMode, PacketType>(row, col),
854 internal::pset1<PacketType>(m_diagImpl.coeff(id)));
855 }
856
857 template <int LoadMode, typename PacketType>
858 EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index id, internal::false_type) const {
859 enum {
860 InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
861 DiagonalPacketLoadMode = plain_enum_min(
862 LoadMode,
863 ((InnerSize % 16) == 0) ? int(Aligned16) : int(evaluator<DiagonalType>::Alignment)) // FIXME hardcoded 16!!
864 };
865 return internal::pmul(m_matImpl.template packet<LoadMode, PacketType>(row, col),
866 m_diagImpl.template packet<DiagonalPacketLoadMode, PacketType>(id));
867 }
868
869 evaluator<DiagonalType> m_diagImpl;
870 evaluator<MatrixType> m_matImpl;
871};
872
873// diagonal * dense
874template <typename Lhs, typename Rhs, int ProductKind, int ProductTag>
875struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape>
876 : diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
877 OnTheLeft> {
878 typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
879 OnTheLeft>
880 Base;
881 using Base::coeff;
882 using Base::m_diagImpl;
883 using Base::m_matImpl;
884 typedef typename Base::Scalar Scalar;
885
886 typedef Product<Lhs, Rhs, ProductKind> XprType;
887 typedef typename XprType::PlainObject PlainObject;
888 typedef typename Lhs::DiagonalVectorType DiagonalType;
889
890 enum { StorageOrder = Base::StorageOrder_ };
891
892 EIGEN_DEVICE_FUNC explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
893
894 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const {
895 return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
896 }
897
898#ifndef EIGEN_GPUCC
899 template <int LoadMode, typename PacketType>
900 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
901 // FIXME: NVCC used to complain about the template keyword, but we have to check whether this is still the case.
902 // See also similar calls below.
903 return this->template packet_impl<LoadMode, PacketType>(
904 row, col, row, std::conditional_t<int(StorageOrder) == RowMajor, internal::true_type, internal::false_type>());
905 }
906
907 template <int LoadMode, typename PacketType>
908 EIGEN_STRONG_INLINE PacketType packet(Index idx) const {
909 return packet<LoadMode, PacketType>(int(StorageOrder) == ColMajor ? idx : 0,
910 int(StorageOrder) == ColMajor ? 0 : idx);
911 }
912#endif
913};
914
915// dense * diagonal
916template <typename Lhs, typename Rhs, int ProductKind, int ProductTag>
917struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape>
918 : diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
919 OnTheRight> {
920 typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
921 OnTheRight>
922 Base;
923 using Base::coeff;
924 using Base::m_diagImpl;
925 using Base::m_matImpl;
926 typedef typename Base::Scalar Scalar;
927
928 typedef Product<Lhs, Rhs, ProductKind> XprType;
929 typedef typename XprType::PlainObject PlainObject;
930
931 enum { StorageOrder = Base::StorageOrder_ };
932
933 EIGEN_DEVICE_FUNC explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
934
935 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const {
936 return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
937 }
938
939#ifndef EIGEN_GPUCC
940 template <int LoadMode, typename PacketType>
941 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
942 return this->template packet_impl<LoadMode, PacketType>(
943 row, col, col, std::conditional_t<int(StorageOrder) == ColMajor, internal::true_type, internal::false_type>());
944 }
945
946 template <int LoadMode, typename PacketType>
947 EIGEN_STRONG_INLINE PacketType packet(Index idx) const {
948 return packet<LoadMode, PacketType>(int(StorageOrder) == ColMajor ? idx : 0,
949 int(StorageOrder) == ColMajor ? 0 : idx);
950 }
951#endif
952};
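// Editor's note (illustrative sketch, not part of Eigen's sources): both diagonal-product
// evaluators avoid materializing the diagonal matrix, e.g.
//
//   Eigen::VectorXd d(4);
//   Eigen::MatrixXd M(4, 4), R(4, 4);
//   R = d.asDiagonal() * M;  // coeff(i, j) = d(i) * M(i, j)   (diagonal * dense)
//   R = M * d.asDiagonal();  // coeff(i, j) = M(i, j) * d(j)   (dense * diagonal)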
953
954/***************************************************************************
955 * Products with permutation matrices
956 ***************************************************************************/
957
963template <typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
964struct permutation_matrix_product;
965
966template <typename ExpressionType, int Side, bool Transposed>
967struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape> {
968 typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
969 typedef remove_all_t<MatrixType> MatrixTypeCleaned;
970
971 template <typename Dest, typename PermutationType>
972 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const PermutationType& perm,
973 const ExpressionType& xpr) {
974 MatrixType mat(xpr);
975 const Index n = Side == OnTheLeft ? mat.rows() : mat.cols();
976 // FIXME we need an is_same for expression that is not sensitive to constness. For instance
977 // is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true.
978 // if(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat))
979 if (is_same_dense(dst, mat)) {
980 // apply the permutation inplace
981 Matrix<bool, PermutationType::RowsAtCompileTime, 1, 0, PermutationType::MaxRowsAtCompileTime> mask(perm.size());
982 mask.fill(false);
983 Index r = 0;
984 while (r < perm.size()) {
985 // search for the next seed
986 while (r < perm.size() && mask[r]) r++;
987 if (r >= perm.size()) break;
988 // we got one, let's follow it until we are back to the seed
989 Index k0 = r++;
990 Index kPrev = k0;
991 mask.coeffRef(k0) = true;
992 for (Index k = perm.indices().coeff(k0); k != k0; k = perm.indices().coeff(k)) {
993 Block<Dest, Side == OnTheLeft ? 1 : Dest::RowsAtCompileTime,
994 Side == OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k)
995 .swap(Block < Dest, Side == OnTheLeft ? 1 : Dest::RowsAtCompileTime,
996 Side == OnTheRight
997 ? 1
998 : Dest::ColsAtCompileTime > (dst, ((Side == OnTheLeft) ^ Transposed) ? k0 : kPrev));
999
1000 mask.coeffRef(k) = true;
1001 kPrev = k;
1002 }
1003 }
1004 } else {
1005 for (Index i = 0; i < n; ++i) {
1006 Block<Dest, Side == OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side == OnTheRight ? 1 : Dest::ColsAtCompileTime>(
1007 dst, ((Side == OnTheLeft) ^ Transposed) ? perm.indices().coeff(i) : i)
1008
1009 =
1010
1011 Block < const MatrixTypeCleaned,
1012 Side == OnTheLeft ? 1 : MatrixTypeCleaned::RowsAtCompileTime,
1013 Side == OnTheRight ? 1
1014 : MatrixTypeCleaned::ColsAtCompileTime >
1015 (mat, ((Side == OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i);
1016 }
1017 }
1018 }
1019};
1020
1021template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1022struct generic_product_impl<Lhs, Rhs, PermutationShape, MatrixShape, ProductTag> {
1023 template <typename Dest>
1024 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1025 permutation_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
1026 }
1027};
1028
1029template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1030struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag> {
1031 template <typename Dest>
1032 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1033 permutation_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
1034 }
1035};
1036
1037template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1038struct generic_product_impl<Inverse<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag> {
1039 template <typename Dest>
1040 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Inverse<Lhs>& lhs, const Rhs& rhs) {
1041 permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
1042 }
1043};
1044
1045template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1046struct generic_product_impl<Lhs, Inverse<Rhs>, MatrixShape, PermutationShape, ProductTag> {
1047 template <typename Dest>
1048 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Inverse<Rhs>& rhs) {
1049 permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
1050 }
1051};
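// Editor's note (illustrative sketch, not part of Eigen's sources): these rules cover products
// with a permutation on either side, as well as with its inverse, e.g.
//
//   Eigen::PermutationMatrix<Eigen::Dynamic> P(3);
//   P.setIdentity();
//   P.applyTranspositionOnTheLeft(0, 2);
//   Eigen::MatrixXd M(3, 3), R(3, 3);
//   R = P * M;            // permutes the rows of M
//   R = M * P;            // permutes the columns of M
//   R = P.inverse() * M;  // applies the inverse permutation (the Transposed == true path)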
1052
1053/***************************************************************************
1054 * Products with transposition matrices
1055 ***************************************************************************/
1056
1057// FIXME could we unify Transpositions and Permutation into a single "shape"??
1058
1063template <typename ExpressionType, int Side, bool Transposed, typename ExpressionShape>
1064struct transposition_matrix_product {
1065 typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
1066 typedef remove_all_t<MatrixType> MatrixTypeCleaned;
1067
1068 template <typename Dest, typename TranspositionType>
1069 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Dest& dst, const TranspositionType& tr,
1070 const ExpressionType& xpr) {
1071 MatrixType mat(xpr);
1072 typedef typename TranspositionType::StorageIndex StorageIndex;
1073 const Index size = tr.size();
1074 StorageIndex j = 0;
1075
1076 if (!is_same_dense(dst, mat)) dst = mat;
1077
1078 for (Index k = (Transposed ? size - 1 : 0); Transposed ? k >= 0 : k < size; Transposed ? --k : ++k)
1079 if (Index(j = tr.coeff(k)) != k) {
1080 if (Side == OnTheLeft)
1081 dst.row(k).swap(dst.row(j));
1082 else if (Side == OnTheRight)
1083 dst.col(k).swap(dst.col(j));
1084 }
1085 }
1086};
1087
1088template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1089struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag> {
1090 template <typename Dest>
1091 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1092 transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
1093 }
1094};
1095
1096template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1097struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag> {
1098 template <typename Dest>
1099 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1100 transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
1101 }
1102};
1103
1104template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1105struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag> {
1106 template <typename Dest>
1107 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) {
1108 transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
1109 }
1110};
1111
1112template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1113struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag> {
1114 template <typename Dest>
1115 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) {
1116 transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
1117 }
1118};
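// Editor's note (illustrative sketch, not part of Eigen's sources): the same pattern applies
// to Transpositions, e.g.
//
//   Eigen::Transpositions<Eigen::Dynamic> T(3);
//   T.setIdentity();
//   T[1] = 2;  // when applied, swap entries 1 and 2
//   Eigen::MatrixXd M(3, 3), R(3, 3);
//   R = T * M;  // swaps rows 1 and 2 of M
//   R = M * T;  // swaps columns 1 and 2 of M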
1119
1120/***************************************************************************
1121 * Skew-symmetric products
1122 * For now we just call the generic implementation
1123 ***************************************************************************/
1124template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1125struct generic_product_impl<Lhs, Rhs, SkewSymmetricShape, MatrixShape, ProductTag> {
1126 template <typename Dest>
1127 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1128 generic_product_impl<typename Lhs::DenseMatrixType, Rhs, DenseShape, MatrixShape, ProductTag>::evalTo(dst, lhs,
1129 rhs);
1130 }
1131};
1132
1133template <typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
1134struct generic_product_impl<Lhs, Rhs, MatrixShape, SkewSymmetricShape, ProductTag> {
1135 template <typename Dest>
1136 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1137 generic_product_impl<Lhs, typename Rhs::DenseMatrixType, MatrixShape, DenseShape, ProductTag>::evalTo(dst, lhs,
1138 rhs);
1139 }
1140};
1141
1142template <typename Lhs, typename Rhs, int ProductTag>
1143struct generic_product_impl<Lhs, Rhs, SkewSymmetricShape, SkewSymmetricShape, ProductTag> {
1144 template <typename Dest>
1145 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) {
1146 generic_product_impl<typename Lhs::DenseMatrixType, typename Rhs::DenseMatrixType, DenseShape, DenseShape,
1147 ProductTag>::evalTo(dst, lhs, rhs);
1148 }
1149};
1150
1151} // end namespace internal
1152
1153} // end namespace Eigen
1154
1155#endif // EIGEN_PRODUCTEVALUATORS_H