Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
SparsePermutation.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2012 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSE_PERMUTATION_H
11#define EIGEN_SPARSE_PERMUTATION_H
12
13// This file implements sparse * permutation products
14
15// IWYU pragma: private
16#include "./InternalHeaderCheck.h"
17
18namespace Eigen {
19
20namespace internal {
21
22template <typename ExpressionType, typename PlainObjectType,
23 bool NeedEval = !is_same<ExpressionType, PlainObjectType>::value>
24struct XprHelper {
25 XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
26 inline const PlainObjectType& xpr() const { return m_xpr; }
27 // this is a new PlainObjectType initialized by xpr
28 const PlainObjectType m_xpr;
29};
30template <typename ExpressionType, typename PlainObjectType>
31struct XprHelper<ExpressionType, PlainObjectType, false> {
32 XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
33 inline const PlainObjectType& xpr() const { return m_xpr; }
34 // this is a reference to xpr
35 const PlainObjectType& m_xpr;
36};
37
38template <typename PermDerived, bool NeedInverseEval>
39struct PermHelper {
40 using IndicesType = typename PermDerived::IndicesType;
41 using PermutationIndex = typename IndicesType::Scalar;
42 using type = PermutationMatrix<IndicesType::SizeAtCompileTime, IndicesType::MaxSizeAtCompileTime, PermutationIndex>;
43 PermHelper(const PermDerived& perm) : m_perm(perm.inverse()) {}
44 inline const type& perm() const { return m_perm; }
45 // this is a new PermutationMatrix initialized by perm.inverse()
46 const type m_perm;
47};
48template <typename PermDerived>
49struct PermHelper<PermDerived, false> {
50 using type = PermDerived;
51 PermHelper(const PermDerived& perm) : m_perm(perm) {}
52 inline const type& perm() const { return m_perm; }
53 // this is a reference to perm
54 const type& m_perm;
55};
56
57template <typename ExpressionType, int Side, bool Transposed>
58struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape> {
59 using MatrixType = typename nested_eval<ExpressionType, 1>::type;
60 using MatrixTypeCleaned = remove_all_t<MatrixType>;
61
62 using Scalar = typename MatrixTypeCleaned::Scalar;
63 using StorageIndex = typename MatrixTypeCleaned::StorageIndex;
64
65 // the actual "return type" is `Dest`. this is a temporary type
66 using ReturnType = SparseMatrix<Scalar, MatrixTypeCleaned::IsRowMajor ? RowMajor : ColMajor, StorageIndex>;
67 using TmpHelper = XprHelper<ExpressionType, ReturnType>;
68
69 static constexpr bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side == OnTheLeft : Side == OnTheRight;
70 static constexpr bool NeedInversePermutation = Transposed ? Side == OnTheLeft : Side == OnTheRight;
71
72 template <typename Dest, typename PermutationType>
73 static inline void permute_outer(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
74 // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
75 // otherwise, just reference `xpr`
76 // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
77 const TmpHelper tmpHelper(xpr);
78 const ReturnType& tmp = tmpHelper.xpr();
79
80 ReturnType result(tmp.rows(), tmp.cols());
81
82 for (Index j = 0; j < tmp.outerSize(); j++) {
83 Index jp = perm.indices().coeff(j);
84 Index jsrc = NeedInversePermutation ? jp : j;
85 Index jdst = NeedInversePermutation ? j : jp;
86 Index begin = tmp.outerIndexPtr()[jsrc];
87 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
88 result.outerIndexPtr()[jdst + 1] += end - begin;
89 }
90
91 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
92 result.resizeNonZeros(result.nonZeros());
93
94 for (Index j = 0; j < tmp.outerSize(); j++) {
95 Index jp = perm.indices().coeff(j);
96 Index jsrc = NeedInversePermutation ? jp : j;
97 Index jdst = NeedInversePermutation ? j : jp;
98 Index begin = tmp.outerIndexPtr()[jsrc];
99 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
100 Index target = result.outerIndexPtr()[jdst];
101 smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target);
102 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
103 }
104 dst = std::move(result);
105 }
106
107 template <typename Dest, typename PermutationType>
108 static inline void permute_inner(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
109 using InnerPermHelper = PermHelper<PermutationType, NeedInversePermutation>;
110 using InnerPermType = typename InnerPermHelper::type;
111
112 // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
113 // otherwise, just reference `xpr`
114 // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
115 const TmpHelper tmpHelper(xpr);
116 const ReturnType& tmp = tmpHelper.xpr();
117
118 // if inverse permutation of inner indices is requested, calculate perm.inverse() (allocation)
119 // otherwise, just reference `perm`
120 const InnerPermHelper permHelper(perm);
121 const InnerPermType& innerPerm = permHelper.perm();
122
123 ReturnType result(tmp.rows(), tmp.cols());
124
125 for (Index j = 0; j < tmp.outerSize(); j++) {
126 Index begin = tmp.outerIndexPtr()[j];
127 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
128 result.outerIndexPtr()[j + 1] += end - begin;
129 }
130
131 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
132 result.resizeNonZeros(result.nonZeros());
133
134 for (Index j = 0; j < tmp.outerSize(); j++) {
135 Index begin = tmp.outerIndexPtr()[j];
136 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
137 Index target = result.outerIndexPtr()[j];
138 std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target,
139 [&innerPerm](StorageIndex i) { return innerPerm.indices().coeff(i); });
140 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
141 }
142 // the inner indices were permuted, and must be sorted
143 result.sortInnerIndices();
144 dst = std::move(result);
145 }
146
147 template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
148 std::enable_if_t<DoOuter, int> = 0>
149 static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
150 permute_outer(dst, perm, xpr);
151 }
152
153 template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
154 std::enable_if_t<!DoOuter, int> = 0>
155 static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
156 permute_inner(dst, perm, xpr);
157 }
158};
159
160} // namespace internal
161
162namespace internal {
163
164template <int ProductTag>
165struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> {
166 typedef Sparse ret;
167};
168template <int ProductTag>
169struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> {
170 typedef Sparse ret;
171};
172
// TODO: the following two specializations are only needed to define the right temporary type through
// typename permutation_matrix_product<Xpr, Side, false, SparseShape>::ReturnType
// whereas it should be correctly handled by traits<Product<> >::PlainObject
176
177template <typename Lhs, typename Rhs, int ProductTag>
178struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, PermutationShape, SparseShape>
179 : public evaluator<typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType> {
180 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
181 typedef typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType PlainObject;
182 typedef evaluator<PlainObject> Base;
183
184 enum { Flags = Base::Flags | EvalBeforeNestingBit };
185
186 explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
187 internal::construct_at<Base>(this, m_result);
188 generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
189 }
190
191 protected:
192 PlainObject m_result;
193};
194
195template <typename Lhs, typename Rhs, int ProductTag>
196struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, SparseShape, PermutationShape>
197 : public evaluator<typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType> {
198 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
199 typedef typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType PlainObject;
200 typedef evaluator<PlainObject> Base;
201
202 enum { Flags = Base::Flags | EvalBeforeNestingBit };
203
204 explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
205 ::new (static_cast<Base*>(this)) Base(m_result);
206 generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
207 }
208
209 protected:
210 PlainObject m_result;
211};
212
213} // end namespace internal
214
217template <typename SparseDerived, typename PermDerived>
218inline const Product<SparseDerived, PermDerived, AliasFreeProduct> operator*(
219 const SparseMatrixBase<SparseDerived>& matrix, const PermutationBase<PermDerived>& perm) {
220 return Product<SparseDerived, PermDerived, AliasFreeProduct>(matrix.derived(), perm.derived());
221}
222
225template <typename SparseDerived, typename PermDerived>
230
233template <typename SparseDerived, typename PermutationType>
234inline const Product<SparseDerived, Inverse<PermutationType>, AliasFreeProduct> operator*(
235 const SparseMatrixBase<SparseDerived>& matrix, const InverseImpl<PermutationType, PermutationStorage>& tperm) {
236 return Product<SparseDerived, Inverse<PermutationType>, AliasFreeProduct>(matrix.derived(), tperm.derived());
237}
238
241template <typename SparseDerived, typename PermutationType>
242inline const Product<Inverse<PermutationType>, SparseDerived, AliasFreeProduct> operator*(
243 const InverseImpl<PermutationType, PermutationStorage>& tperm, const SparseMatrixBase<SparseDerived>& matrix) {
244 return Product<Inverse<PermutationType>, SparseDerived, AliasFreeProduct>(tperm.derived(), matrix.derived());
245}
246
247} // end namespace Eigen
248
#endif  // EIGEN_SPARSE_PERMUTATION_H
Base class for permutations.
Definition PermutationMatrix.h:49
Expression of the product of two arbitrary matrices or vectors.
Definition Product.h:202
Base class of any sparse matrices or sparse expressions.
Definition SparseMatrixBase.h:30
@ OnTheLeft
Definition Constants.h:331
@ OnTheRight
Definition Constants.h:333
const unsigned int EvalBeforeNestingBit
Definition Constants.h:74
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)
Derived & derived()
Definition EigenBase.h:49