Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
SparseSparseProductWithPruning.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2014 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
11#define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20// perform a pseudo in-place sparse * sparse product assuming all matrices are col major
21template <typename Lhs, typename Rhs, typename ResultType>
22static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res,
23 const typename ResultType::RealScalar& tolerance) {
24 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
25
26 typedef typename remove_all_t<Rhs>::Scalar RhsScalar;
27 typedef typename remove_all_t<ResultType>::Scalar ResScalar;
28 typedef typename remove_all_t<Lhs>::StorageIndex StorageIndex;
29
30 // make sure to call innerSize/outerSize since we fake the storage order.
31 Index rows = lhs.innerSize();
32 Index cols = rhs.outerSize();
33 // Index size = lhs.outerSize();
34 eigen_assert(lhs.outerSize() == rhs.innerSize());
35
36 // allocate a temporary buffer
37 AmbiVector<ResScalar, StorageIndex> tempVector(rows);
38
39 // mimics a resizeByInnerOuter:
40 if (ResultType::IsRowMajor)
41 res.resize(cols, rows);
42 else
43 res.resize(rows, cols);
44
45 evaluator<Lhs> lhsEval(lhs);
46 evaluator<Rhs> rhsEval(rhs);
47
48 // estimate the number of non zero entries
49 // given a rhs column containing Y non zeros, we assume that the respective Y columns
50 // of the lhs differs in average of one non zeros, thus the number of non zeros for
51 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
52 // per column of the lhs.
53 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
54 Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
55
56 res.reserve(estimated_nnz_prod);
57 double ratioColRes = double(estimated_nnz_prod) / (double(lhs.rows()) * double(rhs.cols()));
58 for (Index j = 0; j < cols; ++j) {
59 // FIXME:
60 // double ratioColRes = (double(rhs.innerVector(j).nonZeros()) +
61 // double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
62 // let's do a more accurate determination of the nnz ratio for the current column j of res
63 tempVector.init(ratioColRes);
64 tempVector.setZero();
65 for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) {
66 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
67 tempVector.restart();
68 RhsScalar x = rhsIt.value();
69 for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) {
70 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
71 }
72 }
73 res.startVec(j);
74 for (typename AmbiVector<ResScalar, StorageIndex>::Iterator it(tempVector, tolerance); it; ++it)
75 res.insertBackByOuterInner(j, it.index()) = it.value();
76 }
77 res.finalize();
78}
79
80template <typename Lhs, typename Rhs, typename ResultType, int LhsStorageOrder = traits<Lhs>::Flags & RowMajorBit,
81 int RhsStorageOrder = traits<Rhs>::Flags & RowMajorBit,
82 int ResStorageOrder = traits<ResultType>::Flags & RowMajorBit>
83struct sparse_sparse_product_with_pruning_selector;
84
85template <typename Lhs, typename Rhs, typename ResultType>
86struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, ColMajor> {
87 typedef typename ResultType::RealScalar RealScalar;
88
89 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
90 remove_all_t<ResultType> res_(res.rows(), res.cols());
91 internal::sparse_sparse_product_with_pruning_impl<Lhs, Rhs, ResultType>(lhs, rhs, res_, tolerance);
92 res.swap(res_);
93 }
94};
95
96template <typename Lhs, typename Rhs, typename ResultType>
97struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, ColMajor, RowMajor> {
98 typedef typename ResultType::RealScalar RealScalar;
99 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
100 // we need a col-major matrix to hold the result
101 typedef SparseMatrix<typename ResultType::Scalar, ColMajor, typename ResultType::StorageIndex> SparseTemporaryType;
102 SparseTemporaryType res_(res.rows(), res.cols());
103 internal::sparse_sparse_product_with_pruning_impl<Lhs, Rhs, SparseTemporaryType>(lhs, rhs, res_, tolerance);
104 res = res_;
105 }
106};
107
108template <typename Lhs, typename Rhs, typename ResultType>
109struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, RowMajor> {
110 typedef typename ResultType::RealScalar RealScalar;
111 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
112 // let's transpose the product to get a column x column product
113 remove_all_t<ResultType> res_(res.rows(), res.cols());
114 internal::sparse_sparse_product_with_pruning_impl<Rhs, Lhs, ResultType>(rhs, lhs, res_, tolerance);
115 res.swap(res_);
116 }
117};
118
119template <typename Lhs, typename Rhs, typename ResultType>
120struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, RowMajor, ColMajor> {
121 typedef typename ResultType::RealScalar RealScalar;
122 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
123 typedef SparseMatrix<typename Lhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixLhs;
124 typedef SparseMatrix<typename Rhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixRhs;
125 ColMajorMatrixLhs colLhs(lhs);
126 ColMajorMatrixRhs colRhs(rhs);
127 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs, ColMajorMatrixRhs, ResultType>(colLhs, colRhs,
128 res, tolerance);
129
130 // let's transpose the product to get a column x column product
131 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
132 // SparseTemporaryType res_(res.cols(), res.rows());
133 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, res_);
134 // res = res_.transpose();
135 }
136};
137
138template <typename Lhs, typename Rhs, typename ResultType>
139struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, RowMajor> {
140 typedef typename ResultType::RealScalar RealScalar;
141 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
142 typedef SparseMatrix<typename Lhs::Scalar, RowMajor, typename Lhs::StorageIndex> RowMajorMatrixLhs;
143 RowMajorMatrixLhs rowLhs(lhs);
144 sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs, Rhs, ResultType, RowMajor, RowMajor>(rowLhs, rhs,
145 res, tolerance);
146 }
147};
148
149template <typename Lhs, typename Rhs, typename ResultType>
150struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, RowMajor> {
151 typedef typename ResultType::RealScalar RealScalar;
152 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
153 typedef SparseMatrix<typename Rhs::Scalar, RowMajor, typename Lhs::StorageIndex> RowMajorMatrixRhs;
154 RowMajorMatrixRhs rowRhs(rhs);
155 sparse_sparse_product_with_pruning_selector<Lhs, RowMajorMatrixRhs, ResultType, RowMajor, RowMajor, RowMajor>(
156 lhs, rowRhs, res, tolerance);
157 }
158};
159
160template <typename Lhs, typename Rhs, typename ResultType>
161struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, ColMajor, RowMajor, ColMajor> {
162 typedef typename ResultType::RealScalar RealScalar;
163 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
164 typedef SparseMatrix<typename Rhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixRhs;
165 ColMajorMatrixRhs colRhs(rhs);
166 internal::sparse_sparse_product_with_pruning_impl<Lhs, ColMajorMatrixRhs, ResultType>(lhs, colRhs, res, tolerance);
167 }
168};
169
170template <typename Lhs, typename Rhs, typename ResultType>
171struct sparse_sparse_product_with_pruning_selector<Lhs, Rhs, ResultType, RowMajor, ColMajor, ColMajor> {
172 typedef typename ResultType::RealScalar RealScalar;
173 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) {
174 typedef SparseMatrix<typename Lhs::Scalar, ColMajor, typename Lhs::StorageIndex> ColMajorMatrixLhs;
175 ColMajorMatrixLhs colLhs(lhs);
176 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs, Rhs, ResultType>(colLhs, rhs, res, tolerance);
177 }
178};
179
180} // end namespace internal
181
182} // end namespace Eigen
183
184#endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition Core:137