Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
SparseDenseProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2015 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSEDENSEPRODUCT_H
11#define EIGEN_SPARSEDENSEPRODUCT_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <>
21struct product_promote_storage_type<Sparse, Dense, OuterProduct> {
22 typedef Sparse ret;
23};
24template <>
25struct product_promote_storage_type<Dense, Sparse, OuterProduct> {
26 typedef Sparse ret;
27};
28
29template <typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType,
30 int LhsStorageOrder = ((SparseLhsType::Flags & RowMajorBit) == RowMajorBit) ? RowMajor : ColMajor,
31 bool ColPerCol = ((DenseRhsType::Flags & RowMajorBit) == 0) || DenseRhsType::ColsAtCompileTime == 1>
32struct sparse_time_dense_product_impl;
33
34template <typename SparseLhsType, typename DenseRhsType, typename DenseResType>
35struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
36 RowMajor, true> {
37 typedef internal::remove_all_t<SparseLhsType> Lhs;
38 typedef internal::remove_all_t<DenseRhsType> Rhs;
39 typedef internal::remove_all_t<DenseResType> Res;
40 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
41 typedef evaluator<Lhs> LhsEval;
42 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res,
43 const typename Res::Scalar& alpha) {
44 LhsEval lhsEval(lhs);
45
46 Index n = lhs.outerSize();
47#ifdef EIGEN_HAS_OPENMP
48 Index threads = Eigen::nbThreads();
49#endif
50
51 for (Index c = 0; c < rhs.cols(); ++c) {
52#ifdef EIGEN_HAS_OPENMP
53 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems.
54 // It basically represents the minimal amount of work to be done to be worth it.
55 if (threads > 1 && lhsEval.nonZerosEstimate() > 20000) {
56#pragma omp parallel for schedule(dynamic, (n + threads * 4 - 1) / (threads * 4)) num_threads(threads)
57 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i, c);
58 } else
59#endif
60 {
61 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i, c);
62 }
63 }
64 }
65
66 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res,
67 const typename Res::Scalar& alpha, Index i, Index col) {
68 // Two accumulators, which breaks the dependency chain on the accumulator
69 // and allows more instruction-level parallelism in the following loop
70 typename Res::Scalar tmp_a(0);
71 typename Res::Scalar tmp_b(0);
72 for (LhsInnerIterator it(lhsEval, i); it; ++it) {
73 tmp_a += it.value() * rhs.coeff(it.index(), col);
74 ++it;
75 if (it) {
76 tmp_b += it.value() * rhs.coeff(it.index(), col);
77 }
78 }
79 res.coeffRef(i, col) += alpha * (tmp_a + tmp_b);
80 }
81};
82
// FIXME: What is the purpose of the following specialization? Is it meant for the BlockedSparse format?
// It is disabled for now because it conflicts with the generic scalar*matrix and matrix*scalar operators.
// template<typename T1, typename T2/*, int Options_, typename StrideType_*/>
// struct ScalarBinaryOpTraits<T1, Ref<T2/*, Options_, StrideType_*/> >
// {
//   enum {
//     Defined = 1
//   };
//   typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
// };
93
94template <typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
95struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, AlphaType, ColMajor, true> {
96 typedef internal::remove_all_t<SparseLhsType> Lhs;
97 typedef internal::remove_all_t<DenseRhsType> Rhs;
98 typedef internal::remove_all_t<DenseResType> Res;
99 typedef evaluator<Lhs> LhsEval;
100 typedef typename LhsEval::InnerIterator LhsInnerIterator;
101 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) {
102 LhsEval lhsEval(lhs);
103 for (Index c = 0; c < rhs.cols(); ++c) {
104 for (Index j = 0; j < lhs.outerSize(); ++j) {
105 // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
106 typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j, c));
107 for (LhsInnerIterator it(lhsEval, j); it; ++it) res.coeffRef(it.index(), c) += it.value() * rhs_j;
108 }
109 }
110 }
111};
112
113template <typename SparseLhsType, typename DenseRhsType, typename DenseResType>
114struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
115 RowMajor, false> {
116 typedef internal::remove_all_t<SparseLhsType> Lhs;
117 typedef internal::remove_all_t<DenseRhsType> Rhs;
118 typedef internal::remove_all_t<DenseResType> Res;
119 typedef evaluator<Lhs> LhsEval;
120 typedef typename LhsEval::InnerIterator LhsInnerIterator;
121 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res,
122 const typename Res::Scalar& alpha) {
123 Index n = lhs.rows();
124 LhsEval lhsEval(lhs);
125
126#ifdef EIGEN_HAS_OPENMP
127 Index threads = Eigen::nbThreads();
128 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems.
129 // It basically represents the minimal amount of work to be done to be worth it.
130 if (threads > 1 && lhsEval.nonZerosEstimate() * rhs.cols() > 20000) {
131#pragma omp parallel for schedule(dynamic, (n + threads * 4 - 1) / (threads * 4)) num_threads(threads)
132 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i);
133 } else
134#endif
135 {
136 for (Index i = 0; i < n; ++i) processRow(lhsEval, rhs, res, alpha, i);
137 }
138 }
139
140 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, Res& res, const typename Res::Scalar& alpha,
141 Index i) {
142 typename Res::RowXpr res_i(res.row(i));
143 for (LhsInnerIterator it(lhsEval, i); it; ++it) res_i += (alpha * it.value()) * rhs.row(it.index());
144 }
145};
146
147template <typename SparseLhsType, typename DenseRhsType, typename DenseResType>
148struct sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, typename DenseResType::Scalar,
149 ColMajor, false> {
150 typedef internal::remove_all_t<SparseLhsType> Lhs;
151 typedef internal::remove_all_t<DenseRhsType> Rhs;
152 typedef internal::remove_all_t<DenseResType> Res;
153 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
154 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res,
155 const typename Res::Scalar& alpha) {
156 evaluator<Lhs> lhsEval(lhs);
157 for (Index j = 0; j < lhs.outerSize(); ++j) {
158 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
159 for (LhsInnerIterator it(lhsEval, j); it; ++it) res.row(it.index()) += (alpha * it.value()) * rhs_j;
160 }
161 }
162};
163
164template <typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
165inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res,
166 const AlphaType& alpha) {
167 sparse_time_dense_product_impl<SparseLhsType, DenseRhsType, DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
168}
169
170} // end namespace internal
171
172namespace internal {
173
174template <typename Lhs, typename Rhs, int ProductType>
175struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
176 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> > {
177 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
178
179 template <typename Dest>
180 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
181 typedef typename nested_eval<Lhs, ((Rhs::Flags & RowMajorBit) == 0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
182 typedef typename nested_eval<Rhs, ((Lhs::Flags & RowMajorBit) == 0) ? 1 : Dynamic>::type RhsNested;
183 LhsNested lhsNested(lhs);
184 RhsNested rhsNested(rhs);
185 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha);
186 }
187};
188
189template <typename Lhs, typename Rhs, int ProductType>
190struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
191 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> {};
192
193template <typename Lhs, typename Rhs, int ProductType>
194struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
195 : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> > {
196 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
197
198 template <typename Dst>
199 static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
200 typedef typename nested_eval<Lhs, ((Rhs::Flags & RowMajorBit) == 0) ? Dynamic : 1>::type LhsNested;
201 typedef typename nested_eval<Rhs, ((Lhs::Flags & RowMajorBit) == RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type
202 RhsNested;
203 LhsNested lhsNested(lhs);
204 RhsNested rhsNested(rhs);
205
206 // transpose everything
207 Transpose<Dst> dstT(dst);
208 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
209 }
210};
211
212template <typename Lhs, typename Rhs, int ProductType>
213struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
214 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> {};
215
216template <typename LhsT, typename RhsT, bool NeedToTranspose>
217struct sparse_dense_outer_product_evaluator {
218 protected:
219 typedef std::conditional_t<NeedToTranspose, RhsT, LhsT> Lhs1;
220 typedef std::conditional_t<NeedToTranspose, LhsT, RhsT> ActualRhs;
221 typedef Product<LhsT, RhsT, DefaultProduct> ProdXprType;
222
223 // if the actual left-hand side is a dense vector,
224 // then build a sparse-view so that we can seamlessly iterate over it.
225 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind, Sparse>::value, Lhs1,
226 SparseView<Lhs1> >
227 ActualLhs;
228 typedef std::conditional_t<is_same<typename internal::traits<Lhs1>::StorageKind, Sparse>::value, Lhs1 const&,
229 SparseView<Lhs1> >
230 LhsArg;
231
232 typedef evaluator<ActualLhs> LhsEval;
233 typedef evaluator<ActualRhs> RhsEval;
234 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator;
235 typedef typename ProdXprType::Scalar Scalar;
236
237 public:
238 enum { Flags = NeedToTranspose ? RowMajorBit : 0, CoeffReadCost = HugeCost };
239
240 class InnerIterator : public LhsIterator {
241 public:
242 InnerIterator(const sparse_dense_outer_product_evaluator& xprEval, Index outer)
243 : LhsIterator(xprEval.m_lhsXprImpl, 0),
244 m_outer(outer),
245 m_empty(false),
246 m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind())) {}
247
248 EIGEN_STRONG_INLINE Index outer() const { return m_outer; }
249 EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); }
250 EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; }
251
252 EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; }
253 EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); }
254
255 protected:
256 Scalar get(const RhsEval& rhs, Index outer, Dense = Dense()) const { return rhs.coeff(outer); }
257
258 Scalar get(const RhsEval& rhs, Index outer, Sparse = Sparse()) {
259 typename RhsEval::InnerIterator it(rhs, outer);
260 if (it && it.index() == 0 && it.value() != Scalar(0)) return it.value();
261 m_empty = true;
262 return Scalar(0);
263 }
264
265 Index m_outer;
266 bool m_empty;
267 Scalar m_factor;
268 };
269
270 sparse_dense_outer_product_evaluator(const Lhs1& lhs, const ActualRhs& rhs)
271 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) {
272 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
273 }
274
275 // transpose case
276 sparse_dense_outer_product_evaluator(const ActualRhs& rhs, const Lhs1& lhs)
277 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) {
278 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
279 }
280
281 protected:
282 const LhsArg m_lhs;
283 evaluator<ActualLhs> m_lhsXprImpl;
284 evaluator<ActualRhs> m_rhsXprImpl;
285};
286
287// sparse * dense outer product
288template <typename Lhs, typename Rhs>
289struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape>
290 : sparse_dense_outer_product_evaluator<Lhs, Rhs, Lhs::IsRowMajor> {
291 typedef sparse_dense_outer_product_evaluator<Lhs, Rhs, Lhs::IsRowMajor> Base;
292
293 typedef Product<Lhs, Rhs> XprType;
294 typedef typename XprType::PlainObject PlainObject;
295
296 explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs()) {}
297};
298
299template <typename Lhs, typename Rhs>
300struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape>
301 : sparse_dense_outer_product_evaluator<Lhs, Rhs, Rhs::IsRowMajor> {
302 typedef sparse_dense_outer_product_evaluator<Lhs, Rhs, Rhs::IsRowMajor> Base;
303
304 typedef Product<Lhs, Rhs> XprType;
305 typedef typename XprType::PlainObject PlainObject;
306
307 explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs()) {}
308};
309
310} // end namespace internal
311
312} // end namespace Eigen
313
314#endif // EIGEN_SPARSEDENSEPRODUCT_H
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition Core:137
const int HugeCost
Definition Constants.h:48
const int Dynamic
Definition Constants.h:25