162struct product_type_selector<Small, Small, Large> {
163 enum { ret = GemmProduct };
166struct product_type_selector<Large, Small, Large> {
167 enum { ret = GemmProduct };
170struct product_type_selector<Small, Large, Large> {
171 enum { ret = GemmProduct };
174struct product_type_selector<Large, Large, Large> {
175 enum { ret = GemmProduct };
178struct product_type_selector<Large, Small, Small> {
179 enum { ret = CoeffBasedProductMode };
182struct product_type_selector<Small, Large, Small> {
183 enum { ret = CoeffBasedProductMode };
186struct product_type_selector<Large, Large, Small> {
187 enum { ret = GemmProduct };
// Forward declaration of the dense matrix * vector dispatcher.
// Side is OnTheLeft or OnTheRight (which side of the vector the matrix is on),
// StorageOrder is the matrix storage order, and BlasCompatible selects the
// BLAS-like packed kernel versus the generic coefficient-wise fallback.
template <int Side, int StorageOrder, bool BlasCompatible>
struct gemv_dense_selector;
// Forward declaration: when Cond is true, provides a static aligned buffer of
// min(Size, MaxSize) Scalars via data(); when Cond is false, no storage is
// provided (see the specializations below).
template <typename Scalar, int Size, int MaxSize, bool Cond>
struct gemv_static_vector_if;
230template <
typename Scalar,
int Size,
int MaxSize>
231struct gemv_static_vector_if<Scalar, Size, MaxSize, false> {
232 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() {
233 eigen_internal_assert(
false &&
"should never be called");
238template <
typename Scalar,
int Size>
239struct gemv_static_vector_if<Scalar, Size,
Dynamic, true> {
240 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar* data() {
return 0; }
243template <
typename Scalar,
int Size,
int MaxSize>
244struct gemv_static_vector_if<Scalar, Size, MaxSize, true> {
245#if EIGEN_MAX_STATIC_ALIGN_BYTES != 0
246 internal::plain_array<Scalar, internal::min_size_prefer_fixed(Size, MaxSize), 0, AlignedMax> m_data;
247 EIGEN_STRONG_INLINE Scalar* data() {
return m_data.array; }
251 internal::plain_array<Scalar, internal::min_size_prefer_fixed(Size, MaxSize) + EIGEN_MAX_ALIGN_BYTES, 0> m_data;
252 EIGEN_STRONG_INLINE Scalar* data() {
253 return reinterpret_cast<Scalar*
>((std::uintptr_t(m_data.array) & ~(std::size_t(EIGEN_MAX_ALIGN_BYTES - 1))) +
254 EIGEN_MAX_ALIGN_BYTES);
260template <
int StorageOrder,
bool BlasCompatible>
261struct gemv_dense_selector<
OnTheLeft, StorageOrder, BlasCompatible> {
262 template <
typename Lhs,
typename Rhs,
typename Dest>
263 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
264 Transpose<Dest> destT(dest);
266 gemv_dense_selector<OnTheRight, OtherStorageOrder, BlasCompatible>::run(rhs.transpose(), lhs.transpose(), destT,
273 template <
typename Lhs,
typename Rhs,
typename Dest>
274 static inline void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
275 typedef typename Lhs::Scalar LhsScalar;
276 typedef typename Rhs::Scalar RhsScalar;
277 typedef typename Dest::Scalar ResScalar;
279 typedef internal::blas_traits<Lhs> LhsBlasTraits;
280 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
281 typedef internal::blas_traits<Rhs> RhsBlasTraits;
282 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
284 typedef Map<Matrix<ResScalar, Dynamic, 1>, plain_enum_min(AlignedMax, internal::packet_traits<ResScalar>::size)>
287 ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
288 ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
290 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
293 typedef std::conditional_t<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr> ActualDest;
298 EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime == 1),
299 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
300 MightCannotUseDest = ((!EvalToDestAtCompileTime) || ComplexByReal) && (ActualDest::MaxSizeAtCompileTime != 0)
303 typedef const_blas_data_mapper<LhsScalar, Index, ColMajor> LhsMapper;
304 typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;
305 RhsScalar compatibleAlpha = get_factor<ResScalar, RhsScalar>::run(actualAlpha);
307 if (!MightCannotUseDest) {
310 general_matrix_vector_product<Index, LhsScalar, LhsMapper,
ColMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
311 RhsMapper, RhsBlasTraits::NeedToConjugate>::run(actualLhs.rows(), actualLhs.cols(),
312 LhsMapper(actualLhs.data(),
313 actualLhs.outerStride()),
314 RhsMapper(actualRhs.data(),
315 actualRhs.innerStride()),
316 dest.data(), 1, compatibleAlpha);
318 gemv_static_vector_if<ResScalar, ActualDest::SizeAtCompileTime, ActualDest::MaxSizeAtCompileTime,
322 const bool alphaIsCompatible = (!ComplexByReal) || (numext::is_exactly_zero(numext::imag(actualAlpha)));
323 const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
325 ei_declare_aligned_stack_constructed_variable(ResScalar, actualDestPtr, dest.size(),
326 evalToDest ? dest.data() : static_dest.data());
329#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
330 Index size = dest.size();
331 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
333 if (!alphaIsCompatible) {
334 MappedDest(actualDestPtr, dest.size()).setZero();
335 compatibleAlpha = RhsScalar(1);
337 MappedDest(actualDestPtr, dest.size()) = dest;
340 general_matrix_vector_product<Index, LhsScalar, LhsMapper,
ColMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
341 RhsMapper, RhsBlasTraits::NeedToConjugate>::run(actualLhs.rows(), actualLhs.cols(),
342 LhsMapper(actualLhs.data(),
343 actualLhs.outerStride()),
344 RhsMapper(actualRhs.data(),
345 actualRhs.innerStride()),
346 actualDestPtr, 1, compatibleAlpha);
349 if (!alphaIsCompatible)
350 dest.matrix() += actualAlpha * MappedDest(actualDestPtr, dest.size());
352 dest = MappedDest(actualDestPtr, dest.size());
360 template <
typename Lhs,
typename Rhs,
typename Dest>
361 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
362 typedef typename Lhs::Scalar LhsScalar;
363 typedef typename Rhs::Scalar RhsScalar;
364 typedef typename Dest::Scalar ResScalar;
366 typedef internal::blas_traits<Lhs> LhsBlasTraits;
367 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
368 typedef internal::blas_traits<Rhs> RhsBlasTraits;
369 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
370 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
372 std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
373 std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);
375 ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
381 ActualRhsTypeCleaned::InnerStrideAtCompileTime == 1 || ActualRhsTypeCleaned::MaxSizeAtCompileTime == 0
384 gemv_static_vector_if<RhsScalar, ActualRhsTypeCleaned::SizeAtCompileTime,
385 ActualRhsTypeCleaned::MaxSizeAtCompileTime, !DirectlyUseRhs>
388 ei_declare_aligned_stack_constructed_variable(
389 RhsScalar, actualRhsPtr, actualRhs.size(),
390 DirectlyUseRhs ?
const_cast<RhsScalar*
>(actualRhs.data()) : static_rhs.data());
392 if (!DirectlyUseRhs) {
393#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
394 Index size = actualRhs.size();
395 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
397 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
400 typedef const_blas_data_mapper<LhsScalar, Index, RowMajor> LhsMapper;
401 typedef const_blas_data_mapper<RhsScalar, Index, ColMajor> RhsMapper;
402 general_matrix_vector_product<Index, LhsScalar, LhsMapper,
RowMajor, LhsBlasTraits::NeedToConjugate, RhsScalar,
403 RhsMapper, RhsBlasTraits::NeedToConjugate>::
404 run(actualLhs.rows(), actualLhs.cols(), LhsMapper(actualLhs.data(), actualLhs.outerStride()),
405 RhsMapper(actualRhsPtr, 1), dest.data(),
406 dest.col(0).innerStride(),
414 template <
typename Lhs,
typename Rhs,
typename Dest>
415 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
416 EIGEN_STATIC_ASSERT((!nested_eval<Lhs, 1>::Evaluate),
417 EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
420 typename nested_eval<Rhs, 1>::type actual_rhs(rhs);
421 const Index size = rhs.rows();
422 for (Index k = 0; k < size; ++k) dest += (alpha * actual_rhs.coeff(k)) * lhs.col(k);
428 template <
typename Lhs,
typename Rhs,
typename Dest>
429 static void run(
const Lhs& lhs,
const Rhs& rhs, Dest& dest,
const typename Dest::Scalar& alpha) {
430 EIGEN_STATIC_ASSERT((!nested_eval<Lhs, 1>::Evaluate),
431 EIGEN_INTERNAL_COMPILATION_ERROR_OR_YOU_MADE_A_PROGRAMMING_MISTAKE);
432 typename nested_eval<Rhs, Lhs::RowsAtCompileTime>::type actual_rhs(rhs);
433 const Index rows = dest.rows();
434 for (Index i = 0; i < rows; ++i)
435 dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(actual_rhs.transpose())).sum();
451template <
typename Derived>
452template <
typename OtherDerived>
460 ProductIsValid = Derived::ColsAtCompileTime ==
Dynamic || OtherDerived::RowsAtCompileTime ==
Dynamic ||
461 int(Derived::ColsAtCompileTime) == int(OtherDerived::RowsAtCompileTime),
462 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
463 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived, OtherDerived)
469 ProductIsValid || !(AreVectors && SameSizes),
470 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
471 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
472 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
473 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
474#ifdef EIGEN_DEBUG_PRODUCT
475 internal::product_type<Derived, OtherDerived>::debug();
492template <
typename Derived>
493template <
typename OtherDerived>
497 ProductIsValid = Derived::ColsAtCompileTime ==
Dynamic || OtherDerived::RowsAtCompileTime ==
Dynamic ||
498 int(Derived::ColsAtCompileTime) == int(OtherDerived::RowsAtCompileTime),
499 AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
500 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived, OtherDerived)
506 ProductIsValid || !(AreVectors && SameSizes),
507 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
508 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
509 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
510 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)