Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)

ConservativeSparseSparseProduct.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2015 Gael Guennebaud <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
#define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

template <typename Lhs, typename Rhs, typename ResultType>
static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res,
                                                     bool sortedInsertion = false) {
  typedef typename remove_all_t<Lhs>::Scalar LhsScalar;
  typedef typename remove_all_t<Rhs>::Scalar RhsScalar;
  typedef typename remove_all_t<ResultType>::Scalar ResScalar;

  // make sure to call innerSize/outerSize since we fake the storage order.
  Index rows = lhs.innerSize();
  Index cols = rhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  ei_declare_aligned_stack_constructed_variable(bool, mask, rows, 0);
  ei_declare_aligned_stack_constructed_variable(ResScalar, values, rows, 0);
  ei_declare_aligned_stack_constructed_variable(Index, indices, rows, 0);

  std::memset(mask, 0, sizeof(bool) * rows);

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  // Estimate the number of non-zero entries:
  // given a rhs column containing Y non-zeros, we assume that the corresponding Y columns
  // of the lhs differ from one another by one non-zero on average, so the number of non-zeros
  // produced by the product of one rhs column with the lhs is roughly X+Y, where X is the
  // average number of non-zeros per column of the lhs.
  // Summing over all rhs columns yields the estimate nnz(lhs*rhs) ~ nnz(lhs) + nnz(rhs).
  Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();

  res.setZero();
  res.reserve(Index(estimated_nnz_prod));
  // we compute each column of the result, one after the other
  for (Index j = 0; j < cols; ++j) {
    res.startVec(j);
    Index nnz = 0;
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) {
      RhsScalar y = rhsIt.value();
      Index k = rhsIt.index();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt) {
        Index i = lhsIt.index();
        LhsScalar x = lhsIt.value();
        if (!mask[i]) {
          mask[i] = true;
          values[i] = x * y;
          indices[nnz] = i;
          ++nnz;
        } else
          values[i] += x * y;
      }
    }
    if (!sortedInsertion) {
      // unordered insertion
      for (Index k = 0; k < nnz; ++k) {
        Index i = indices[k];
        res.insertBackByOuterInnerUnordered(j, i) = values[i];
        mask[i] = false;
      }
    } else {
      // alternative ordered insertion code:
      const Index t200 = rows / 11;  // 11 == (log2(200)*1.39)
      const Index t = (rows * 100) / 139;

      // FIXME reserve nnz non zeros
      // FIXME implement faster sorting algorithms for very small nnz
      // if the result is sparse enough => use a quick sort
      // otherwise => loop through the entire vector
      // To avoid performing an expensive log2 when the result is
      // clearly very sparse, we use a linear bound up to 200.
      if ((nnz < 200 && nnz < t200) || nnz * numext::log2(int(nnz)) < t) {
        if (nnz > 1) std::sort(indices, indices + nnz);
        for (Index k = 0; k < nnz; ++k) {
          Index i = indices[k];
          res.insertBackByOuterInner(j, i) = values[i];
          mask[i] = false;
        }
      } else {
        // dense path
        for (Index i = 0; i < rows; ++i) {
          if (mask[i]) {
            mask[i] = false;
            res.insertBackByOuterInner(j, i) = values[i];
          }
        }
      }
    }
  }
  res.finalize();
}
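
// Usage sketch (illustrative only, not part of the implementation): the kernel above accumulates
// one output column at a time into the dense mask/values/indices workspace, in the spirit of
// Gustavson's algorithm. It is what a plain sparse*sparse product ultimately dispatches to, e.g.
//
//   SparseMatrix<double> A(n, k), B(k, m);  // A, B, C and the sizes are hypothetical placeholders
//   SparseMatrix<double> C = A * B;         // "conservative": structurally produced entries are
//                                           // kept even when they are numerically zero
//
// whereas (A * B).pruned() goes through the separate pruning implementation instead.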

}  // end namespace internal

namespace internal {

// Helper template to generate new sparse matrix types
template <class Source, int Order>
using WithStorageOrder = SparseMatrix<typename Source::Scalar, Order, typename Source::StorageIndex>;

template <typename Lhs, typename Rhs, typename ResultType,
          int LhsStorageOrder = (traits<Lhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
          int RhsStorageOrder = (traits<Rhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
          int ResStorageOrder = (traits<ResultType>::Flags & RowMajorBit) ? RowMajor : ColMajor>
struct conservative_sparse_sparse_product_selector;

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, ColMajor> {
  typedef remove_all_t<Lhs> LhsCleaned;
  typedef typename LhsCleaned::Scalar Scalar;

  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using RowMajorMatrix = WithStorageOrder<ResultType, RowMajor>;
    using ColMajorMatrixAux = WithStorageOrder<ResultType, ColMajor>;

    // If the result is tall and thin (in the extreme case a column vector)
    // then it is faster to sort the coefficients in place instead of transposing twice.
    // FIXME: the following heuristic is probably not very good.
    if (lhs.rows() > rhs.cols()) {
      using ColMajorMatrix = typename sparse_eval<ColMajorMatrixAux, ResultType::RowsAtCompileTime,
                                                  ResultType::ColsAtCompileTime, ColMajorMatrixAux::Flags>::type;
      ColMajorMatrix resCol(lhs.rows(), rhs.cols());
      // perform sorted insertion
      internal::conservative_sparse_sparse_product_impl<Lhs, Rhs, ColMajorMatrix>(lhs, rhs, resCol, true);
      res = resCol.markAsRValue();
    } else {
      ColMajorMatrixAux resCol(lhs.rows(), rhs.cols());
      // resort to transposition to sort the entries
      internal::conservative_sparse_sparse_product_impl<Lhs, Rhs, ColMajorMatrixAux>(lhs, rhs, resCol, false);
      RowMajorMatrix resRow(resCol);
      res = resRow.markAsRValue();
    }
  }
};
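
// Illustrative note: the heuristic above uses the sorted-insertion path of the kernel when the
// result has more rows than columns, and otherwise computes with unordered insertion and sorts
// the entries by round-tripping through a row-major temporary. A hypothetical tall-and-thin case:
//
//   SparseMatrix<double> A(10000, 50), B(50, 3);  // sizes chosen only for illustration
//   SparseMatrix<double> C = A * B;               // 10000x3 result: sorted insertion in place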

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, ColMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using RowMajorRhs = WithStorageOrder<Rhs, RowMajor>;
    using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
    RowMajorRhs rhsRow = rhs;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<RowMajorRhs, Lhs, RowMajorRes>(rhsRow, lhs, resRow);
    res = resRow;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, ColMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using RowMajorLhs = WithStorageOrder<Lhs, RowMajor>;
    using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
    RowMajorLhs lhsRow = lhs;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs, RowMajorLhs, RowMajorRes>(rhs, lhsRow, resRow);
    res = resRow;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, ColMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs, Lhs, RowMajorRes>(rhs, lhs, resRow);
    res = resRow;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, RowMajor> {
  typedef typename traits<remove_all_t<Lhs>>::Scalar Scalar;

  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
    ColMajorRes resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Lhs, Rhs, ColMajorRes>(lhs, rhs, resCol);
    res = resCol;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, RowMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorLhs = WithStorageOrder<Lhs, ColMajor>;
    using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
    ColMajorLhs lhsCol = lhs;
    ColMajorRes resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<ColMajorLhs, Rhs, ColMajorRes>(lhsCol, rhs, resCol);
    res = resCol;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, RowMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorRhs = WithStorageOrder<Rhs, ColMajor>;
    using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
    ColMajorRhs rhsCol = rhs;
    ColMajorRes resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Lhs, ColMajorRhs, ColMajorRes>(lhs, rhsCol, resCol);
    res = resCol;
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, RowMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorRes = WithStorageOrder<ResultType, ColMajor>;
    using RowMajorRes = WithStorageOrder<ResultType, RowMajor>;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs, Lhs, RowMajorRes>(rhs, lhs, resRow);
    // sort the non zeros:
    ColMajorRes resCol(resRow);
    res = resCol;
  }
};
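
// Illustrative note: the specializations above cover every combination of lhs/rhs/result storage
// orders, either by converting an operand into a temporary with the required order or by swapping
// the operands and writing into a result of the opposite order (a row-major sparse matrix traversed
// by the column-major kernel behaves like its transpose). Mixed-order products therefore work
// transparently, e.g.
//
//   SparseMatrix<double, RowMajor> A(n, k);    // A, B, C and the sizes are hypothetical placeholders
//   SparseMatrix<double, ColMajor> B(k, m);
//   SparseMatrix<double, RowMajor> C = A * B;  // dispatched to one of the selectors above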

}  // end namespace internal

namespace internal {

template <typename Lhs, typename Rhs, typename ResultType>
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
  typedef typename remove_all_t<Lhs>::Scalar LhsScalar;
  typedef typename remove_all_t<Rhs>::Scalar RhsScalar;
  Index cols = rhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  for (Index j = 0; j < cols; ++j) {
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) {
      RhsScalar y = rhsIt.value();
      Index k = rhsIt.index();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt) {
        Index i = lhsIt.index();
        LhsScalar x = lhsIt.value();
        res.coeffRef(i, j) += x * y;
      }
    }
  }
}
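
// Usage sketch (illustrative only): this simpler kernel is used when a sparse*sparse product is
// accumulated directly into a dense destination, so no sparse temporary or sorting is needed for
// the result, e.g.
//
//   SparseMatrix<double> A(n, k), B(k, m);  // A, B, D and the sizes are hypothetical placeholders
//   MatrixXd D(n, m);
//   D = A * B;                              // dense destination: handled by the selectors below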

}  // end namespace internal

namespace internal {

template <typename Lhs, typename Rhs, typename ResultType,
          int LhsStorageOrder = (traits<Lhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
          int RhsStorageOrder = (traits<Rhs>::Flags & RowMajorBit) ? RowMajor : ColMajor>
struct sparse_sparse_to_dense_product_selector;

template <typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    internal::sparse_sparse_to_dense_product_impl<Lhs, Rhs, ResultType>(lhs, rhs, res);
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorLhs = WithStorageOrder<Lhs, ColMajor>;
    ColMajorLhs lhsCol(lhs);
    internal::sparse_sparse_to_dense_product_impl<ColMajorLhs, Rhs, ResultType>(lhsCol, rhs, res);
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    using ColMajorRhs = WithStorageOrder<Rhs, ColMajor>;
    ColMajorRhs rhsCol(rhs);
    internal::sparse_sparse_to_dense_product_impl<Lhs, ColMajorRhs, ResultType>(lhs, rhsCol, res);
  }
};

template <typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor> {
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) {
    // Both operands are row-major: traversed by the column-major kernel they behave like their
    // transposes, so computing rhs*lhs into the transposed view of res yields res = lhs*rhs.
    Transpose<ResultType> trRes(res);
    internal::sparse_sparse_to_dense_product_impl<Rhs, Lhs, Transpose<ResultType>>(rhs, lhs, trRes);
  }
};

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H