10#ifndef EIGEN_SPARSEDENSEPRODUCT_H
11#define EIGEN_SPARSEDENSEPRODUCT_H
14#include "./InternalHeaderCheck.h"
// Specialization of product_promote_storage_type for a Sparse lhs and a Dense
// rhs combined as an OuterProduct. The body of this struct is not part of the
// visible listing.
21struct product_promote_storage_type<Sparse, Dense, OuterProduct> {
// Specialization of product_promote_storage_type for a Dense lhs and a Sparse
// rhs combined as an OuterProduct. The body of this struct is not part of the
// visible listing.
25struct product_promote_storage_type<Dense, Sparse, OuterProduct> {
// Primary template (declaration only) of the sparse * dense product kernel.
// The specializations that follow provide the static run() entry point.
// ColPerCol defaults to true when the dense rhs does not have RowMajorBit set
// in its Flags, or when it has exactly one column at compile time.
// NOTE(review): the listing jumps from line 29 to line 31, so a template
// parameter may be missing between AlphaType and ColPerCol - confirm against
// the full source.
29template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType,
31 bool ColPerCol = ((DenseRhsType::Flags &
RowMajorBit) == 0) || DenseRhsType::ColsAtCompileTime == 1>
32struct sparse_time_dense_product_impl;
// Specialization of sparse_time_dense_product_impl whose AlphaType is fixed to
// DenseResType::Scalar. It walks the dense rhs one column at a time and, for
// each outer index of the sparse lhs, accumulates a scalar sum into one
// coefficient of the result.
// NOTE(review): the trailing template arguments of this specialization
// (listing line 36) and the declaration of lhsEval (listing line 44) are not
// visible here.
34template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
35struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
37 typedef internal::remove_all_t<SparseLhsType> Lhs;
38 typedef internal::remove_all_t<DenseRhsType> Rhs;
39 typedef internal::remove_all_t<DenseResType> Res;
40 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
41 typedef evaluator<Lhs> LhsEval;
// Accumulates alpha * lhs * rhs into res.
//   lhs   - sparse operand; its outerSize() bounds the inner loop over i.
//   rhs   - dense operand; processed one column c at a time.
//   res   - dense destination, updated in place with +=.
//   alpha - scalar factor applied to every accumulated sum.
42 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
43 const typename Res::Scalar& alpha) {
46 Index n = lhs.outerSize();
47#ifdef EIGEN_HAS_OPENMP
48 Index threads = Eigen::nbThreads();
51 for (Index c = 0; c < rhs.cols(); ++c) {
52#ifdef EIGEN_HAS_OPENMP
// Take the parallel path only with more than one thread and when the
// estimated number of non-zeros of the lhs exceeds 20000. The dynamic
// schedule chunk size is n / (4 * threads), rounded up.
55 if (threads > 1 && lhsEval.nonZerosEstimate() > 20000) {
56#pragma omp parallel for schedule(dynamic, (n + threads * 4 - 1) / (threads * 4)) num_threads(threads)
57 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i, c);
// Serial loop over the same outer indices.
61 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i, c);
// Handles one (outer index i, rhs column col) pair: sums
// it.value() * rhs.coeff(it.index(), col) over the entries of lhs at outer
// index i, then adds alpha times that sum to res.coeffRef(i, col).
66 static void processRow(
const LhsEval& lhsEval,
const DenseRhsType& rhs, DenseResType& res,
67 const typename Res::Scalar& alpha, Index i, Index col) {
// Two separate accumulators, combined only at the end.
70 typename Res::Scalar tmp_a(0);
71 typename Res::Scalar tmp_b(0);
72 for (LhsInnerIterator it(lhsEval, i); it; ++it) {
73 tmp_a += it.value() * rhs.coeff(it.index(), col);
// NOTE(review): listing lines 74-75 are not shown; presumably the iterator is
// advanced (and checked) between the two accumulations - confirm against the
// full source.
76 tmp_b += it.value() * rhs.coeff(it.index(), col);
79 res.coeffRef(i, col) += alpha * (tmp_a + tmp_b);
// Specialization of sparse_time_dense_product_impl selected by the trailing
// arguments (ColMajor, true). For every rhs column c and every outer index j
// of the lhs, the rhs coefficient (j, c) is pre-multiplied by alpha and then
// scattered into the result through the non-zeros stored at outer index j.
94template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
95struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, AlphaType,
ColMajor, true> {
96 typedef internal::remove_all_t<SparseLhsType> Lhs;
97 typedef internal::remove_all_t<DenseRhsType> Rhs;
98 typedef internal::remove_all_t<DenseResType> Res;
99 typedef evaluator<Lhs> LhsEval;
100 typedef typename LhsEval::InnerIterator LhsInnerIterator;
// Accumulates alpha * lhs * rhs into res (res is updated with +=).
//   lhs   - sparse operand, iterated by outer index j.
//   rhs   - dense operand, read coefficient by coefficient.
//   res   - dense destination.
//   alpha - scaling factor; its type may differ from the rhs scalar type.
101 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const AlphaType& alpha) {
102 LhsEval lhsEval(lhs);
103 for (Index c = 0; c < rhs.cols(); ++c) {
104 for (Index j = 0; j < lhs.outerSize(); ++j) {
// The type of alpha * rhs(j, c) is deduced through ScalarBinaryOpTraits; the
// product is computed once per (j, c) and reused for every non-zero below.
106 typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j, c));
107 for (LhsInnerIterator it(lhsEval, j); it; ++it) res.coeffRef(it.index(), c) += it.value() * rhs_j;
// Specialization of sparse_time_dense_product_impl whose AlphaType is fixed to
// DenseResType::Scalar. Each index i in [0, lhs.rows()) updates one full row
// of the result with row expressions of the rhs.
// NOTE(review): the trailing template arguments of this specialization
// (listing line 115) are not visible here.
113template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
114struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
116 typedef internal::remove_all_t<SparseLhsType> Lhs;
117 typedef internal::remove_all_t<DenseRhsType> Rhs;
118 typedef internal::remove_all_t<DenseResType> Res;
119 typedef evaluator<Lhs> LhsEval;
120 typedef typename LhsEval::InnerIterator LhsInnerIterator;
// Accumulates alpha * lhs * rhs into res, one result row per index i.
//   lhs   - sparse operand; lhs.rows() bounds the loop.
//   rhs   - dense operand, read by rows.
//   res   - dense destination, updated in place with +=.
//   alpha - scalar factor applied to every lhs value.
121 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
122 const typename Res::Scalar& alpha) {
123 Index n = lhs.rows();
124 LhsEval lhsEval(lhs);
126#ifdef EIGEN_HAS_OPENMP
127 Index threads = Eigen::nbThreads();
// Take the parallel path only with more than one thread and when
// (estimated non-zeros of lhs) * (rhs columns) exceeds 20000. The dynamic
// schedule chunk size is n / (4 * threads), rounded up.
130 if (threads > 1 && lhsEval.nonZerosEstimate() * rhs.cols() > 20000) {
131#pragma omp parallel for schedule(dynamic, (n + threads * 4 - 1) / (threads * 4)) num_threads(threads)
132 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i);
// Serial loop over the same indices.
136 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i);
// Updates row i of res: for every entry of lhs at outer index i, adds
// (alpha * value) * rhs.row(index) to that row.
// NOTE(review): the end of this signature (listing line 141, which must
// declare `i`) is not visible here.
140 static void processRow(
const LhsEval& lhsEval,
const DenseRhsType& rhs, Res& res,
const typename Res::Scalar& alpha,
142 typename Res::RowXpr res_i(res.row(i));
143 for (LhsInnerIterator it(lhsEval, i); it; ++it) res_i += (alpha * it.value()) * rhs.row(it.index());
// Specialization of sparse_time_dense_product_impl whose AlphaType is fixed to
// DenseResType::Scalar. For each outer index j of the lhs, row j of the rhs is
// scaled and added to the result rows addressed by the non-zeros at j.
// NOTE(review): the trailing template arguments of this specialization
// (listing line 149) are not visible here.
147template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
148struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
150 typedef internal::remove_all_t<SparseLhsType> Lhs;
151 typedef internal::remove_all_t<DenseRhsType> Rhs;
152 typedef internal::remove_all_t<DenseResType> Res;
153 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
// Accumulates alpha * lhs * rhs into res (res is updated with +=).
//   lhs   - sparse operand, iterated by outer index j.
//   rhs   - dense operand, read by rows.
//   res   - dense destination, written by rows.
//   alpha - scalar factor applied to every lhs value.
154 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
155 const typename Res::Scalar& alpha) {
156 evaluator<Lhs> lhsEval(lhs);
157 for (Index j = 0; j < lhs.outerSize(); ++j) {
// Row j of the rhs is bound once and reused for every non-zero at outer j.
158 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
159 for (LhsInnerIterator it(lhsEval, j); it; ++it) res.row(it.index()) += (alpha * it.value()) * rhs_j;
// Entry point for the sparse * dense kernels: forwards all four arguments to
// sparse_time_dense_product_impl<...>::run, leaving the remaining template
// parameters of the implementation struct at their defaults so that the
// matching specialization is picked at compile time.
//   lhs   - sparse operand.
//   rhs   - dense operand.
//   res   - dense destination passed on to run().
//   alpha - scaling factor passed on to run().
164template <
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
165inline void sparse_time_dense_product(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
166 const AlphaType& alpha) {
167 sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
// generic_product_impl specialization for a SparseShape lhs and a DenseShape
// rhs. It derives from generic_product_impl_base (passing itself as the last
// template argument) and supplies scaleAndAddTo.
174template <
typename Lhs,
typename Rhs,
int ProductType>
175struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
176 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> > {
177 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
// Passes dst, both operands and alpha to internal::sparse_time_dense_product
// after wrapping each operand in its nested_eval type.
//   dst   - destination handed to the kernel.
//   lhs   - sparse operand.
//   rhs   - dense operand.
//   alpha - scaling factor.
179 template <
typename Dest>
180 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha) {
// The second nested_eval argument depends on the storage order of the *other*
// operand: 1 when that operand lacks RowMajorBit, otherwise
// Rhs::ColsAtCompileTime (for the lhs) or Dynamic (for the rhs).
181 typedef typename nested_eval<Lhs, ((Rhs::Flags &
RowMajorBit) == 0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
182 typedef typename nested_eval<Rhs, ((Lhs::Flags &
RowMajorBit) == 0) ? 1 :
Dynamic>::type RhsNested;
183 LhsNested lhsNested(lhs);
184 RhsNested rhsNested(rhs);
185 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha);
// generic_product_impl for a SparseTriangularShape lhs and a DenseShape rhs:
// an empty struct that inherits everything from the
// (SparseShape, DenseShape) specialization.
189template <
typename Lhs,
typename Rhs,
int ProductType>
190struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
191 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> {};
// generic_product_impl specialization for a DenseShape lhs and a SparseShape
// rhs. It derives from generic_product_impl_base (passing itself as the last
// template argument) and supplies scaleAndAddTo.
193template <
typename Lhs,
typename Rhs,
int ProductType>
194struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
195 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> > {
196 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
// Reuses the sparse * dense kernel by working on the transposed problem: the
// kernel receives rhs^T as its sparse operand, lhs^T as its dense operand, and
// a Transpose view of dst as its destination.
//   dst   - destination, accessed through Transpose<Dst>.
//   lhs   - dense operand.
//   rhs   - sparse operand.
//   alpha - scaling factor.
198 template <
typename Dst>
199 static void scaleAndAddTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha) {
200 typedef typename nested_eval<Lhs, ((Rhs::Flags &
RowMajorBit) == 0) ?
Dynamic : 1>::type LhsNested;
// NOTE(review): the name of this typedef (listing line 202) is not visible;
// the code below uses it as RhsNested.
201 typedef typename nested_eval<Rhs, ((Lhs::Flags &
RowMajorBit) ==
RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type
203 LhsNested lhsNested(lhs);
204 RhsNested rhsNested(rhs);
207 Transpose<Dst> dstT(dst);
208 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
// generic_product_impl for a DenseShape lhs and a SparseTriangularShape rhs:
// an empty struct that inherits everything from the
// (DenseShape, SparseShape) specialization.
212template <
typename Lhs,
typename Rhs,
int ProductType>
213struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
214 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> {};
// Evaluator shared by the two outer-product product_evaluator specializations.
// When NeedToTranspose is true the two operand types swap roles (Lhs1 becomes
// RhsT and ActualRhs becomes LhsT), and the InnerIterator swaps its row/col
// reporting accordingly.
// NOTE(review): several lines of this struct are absent from the listing
// (typedef tails, access specifiers, member declarations such as m_lhs,
// m_outer, m_empty and m_factor, and the remainder of the sparse get()
// overload) - consult the full source before relying on details.
216template <
typename LhsT,
typename RhsT,
bool NeedToTranspose>
217struct sparse_dense_outer_product_evaluator {
219 typedef std::conditional_t<NeedToTranspose, RhsT, LhsT> Lhs1;
220 typedef std::conditional_t<NeedToTranspose, LhsT, RhsT> ActualRhs;
221 typedef Product<LhsT, RhsT, DefaultProduct> ProdXprType;
// Both typedefs below select Lhs1 (by value, resp. by const reference) when
// its StorageKind is Sparse; the alternative branch and the typedef names are
// not visible in this listing.
225 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind, Sparse>::value, Lhs1,
228 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind, Sparse>::value, Lhs1
const&,
232 typedef evaluator<ActualLhs> LhsEval;
233 typedef evaluator<ActualRhs> RhsEval;
234 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator;
235 typedef typename ProdXprType::Scalar Scalar;
// Iterator over one outer slice of the product. It extends the lhs inner
// iterator and multiplies each lhs value by a single factor read from the rhs.
240 class InnerIterator :
public LhsIterator {
// The base iterator always starts at outer index 0 of the lhs evaluator; the
// factor is fetched from the rhs evaluator at `outer`, dispatching on the
// StorageKind of ActualRhs to pick the dense or sparse get() overload.
242 InnerIterator(
const sparse_dense_outer_product_evaluator& xprEval, Index outer)
243 : LhsIterator(xprEval.m_lhsXprImpl, 0),
246 m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind())) {}
248 EIGEN_STRONG_INLINE Index outer()
const {
return m_outer; }
// row() and col() exchange the roles of m_outer and the base iterator's
// index() when NeedToTranspose is set.
249 EIGEN_STRONG_INLINE Index row()
const {
return NeedToTranspose ? m_outer : LhsIterator::index(); }
250 EIGEN_STRONG_INLINE Index col()
const {
return NeedToTranspose ? LhsIterator::index() : m_outer; }
// Product coefficient: current lhs value times the cached rhs factor.
252 EIGEN_STRONG_INLINE Scalar value()
const {
return LhsIterator::value() * m_factor; }
// Valid only while the base iterator is valid and m_empty is false.
253 EIGEN_STRONG_INLINE
operator bool()
const {
return LhsIterator::operator bool() && (!m_empty); }
// Dense rhs: the factor is simply the coefficient at `outer`.
256 Scalar get(
const RhsEval& rhs, Index outer, Dense = Dense())
const {
return rhs.coeff(outer); }
// Sparse rhs: iterate the rhs at `outer`; if the first stored entry has inner
// index 0 and a non-zero value, that value is the factor. The fallback path
// is not visible in this listing.
258 Scalar get(
const RhsEval& rhs, Index outer, Sparse = Sparse()) {
259 typename RhsEval::InnerIterator it(rhs, outer);
260 if (it && it.index() == 0 && it.value() != Scalar(0))
return it.value();
// Constructor taking (lhs, rhs): stores lhs in m_lhs and builds the two
// operand evaluators from m_lhs and rhs.
270 sparse_dense_outer_product_evaluator(
const Lhs1& lhs,
const ActualRhs& rhs)
271 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) {
272 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
// Constructor taking the same operands in the opposite argument order
// (rhs, lhs); the member initialization is identical.
276 sparse_dense_outer_product_evaluator(
const ActualRhs& rhs,
const Lhs1& lhs)
277 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) {
278 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
// Evaluators for the two operands, read by InnerIterator.
283 evaluator<ActualLhs> m_lhsXprImpl;
284 evaluator<ActualRhs> m_rhsXprImpl;
// product_evaluator for an OuterProduct with a SparseShape lhs and a
// DenseShape rhs. All behavior comes from sparse_dense_outer_product_evaluator,
// instantiated with NeedToTranspose = Lhs::IsRowMajor.
288template <
typename Lhs,
typename Rhs>
289struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape>
290 : sparse_dense_outer_product_evaluator<Lhs, Rhs, Lhs::IsRowMajor> {
291 typedef sparse_dense_outer_product_evaluator<Lhs, Rhs, Lhs::IsRowMajor> Base;
293 typedef Product<Lhs, Rhs> XprType;
294 typedef typename XprType::PlainObject PlainObject;
// Builds the base evaluator from the two operands of the product expression.
296 explicit product_evaluator(
const XprType& xpr) : Base(xpr.lhs(), xpr.rhs()) {}
// product_evaluator for an OuterProduct with a DenseShape lhs and a
// SparseShape rhs. All behavior comes from sparse_dense_outer_product_evaluator,
// instantiated with NeedToTranspose = Rhs::IsRowMajor.
299template <
typename Lhs,
typename Rhs>
300struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape>
301 : sparse_dense_outer_product_evaluator<Lhs, Rhs, Rhs::IsRowMajor> {
302 typedef sparse_dense_outer_product_evaluator<Lhs, Rhs, Rhs::IsRowMajor> Base;
304 typedef Product<Lhs, Rhs> XprType;
305 typedef typename XprType::PlainObject PlainObject;
// Builds the base evaluator from the two operands of the product expression.
307 explicit product_evaluator(
const XprType& xpr) : Base(xpr.lhs(), xpr.rhs()) {}
// Cross-references emitted by the documentation generator for symbols used above:
//   ColMajor                       - Definition: Constants.h:318
//   RowMajor                       - Definition: Constants.h:320
//   const unsigned int RowMajorBit - Definition: Constants.h:70
//   Namespace containing all symbols from the Eigen library.
//                                  - Definition: Core:137
//   const int HugeCost             - Definition: Constants.h:48
//   const int Dynamic              - Definition: Constants.h:25