10#ifndef EIGEN_SPARSE_PERMUTATION_H
11#define EIGEN_SPARSE_PERMUTATION_H
16#include "./InternalHeaderCheck.h"
// Primary XprHelper template: gives access to an expression as a PlainObjectType.
// NOTE: the `struct XprHelper {` line and the closing brace are not present in this listing.
// NeedEval defaults to true when ExpressionType and PlainObjectType are different types.
22template <
typename ExpressionType,
typename PlainObjectType,
23 bool NeedEval = !is_same<ExpressionType, PlainObjectType>::value>
// Initializes the by-value member m_xpr from the incoming expression.
25 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
// Returns a const reference to the stored PlainObjectType.
26 inline const PlainObjectType& xpr()
const {
return m_xpr; }
// Held by value in this variant (the NeedEval == false specialization holds a reference instead).
28 const PlainObjectType m_xpr;
// XprHelper specialization for NeedEval == false.
// With the default NeedEval argument this is selected when ExpressionType is the same type as
// PlainObjectType. No object is stored by value: only a reference to the argument is kept.
// NOTE: the closing brace of the struct is not present in this listing.
30template <
typename ExpressionType,
typename PlainObjectType>
31struct XprHelper<ExpressionType, PlainObjectType, false> {
// Binds the reference member m_xpr to the constructor argument.
32 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
// Returns the bound reference.
33 inline const PlainObjectType& xpr()
const {
return m_xpr; }
// Reference member: the referenced object presumably has to outlive this helper — confirm at call sites.
35 const PlainObjectType& m_xpr;
// Primary PermHelper template: exposes a permutation built from the inverse of the one passed in.
// NOTE: the `struct PermHelper {` line, the declaration of m_perm and the closing brace are not
// present in this listing.
38template <
typename PermDerived,
bool NeedInverseEval>
// Index container type of the permutation and the scalar type of its entries.
40 using IndicesType =
typename PermDerived::IndicesType;
41 using PermutationIndex =
typename IndicesType::Scalar;
// PermutationMatrix type with the same compile-time sizes and index scalar as PermDerived's indices.
42 using type = PermutationMatrix<IndicesType::SizeAtCompileTime, IndicesType::MaxSizeAtCompileTime, PermutationIndex>;
// Initializes m_perm from perm.inverse().
43 PermHelper(
const PermDerived& perm) : m_perm(perm.
inverse()) {}
// Returns a const reference to m_perm.
44 inline const type& perm()
const {
return m_perm; }
// PermHelper specialization for NeedInverseEval == false: the permutation is used as passed in,
// without calling inverse().
// NOTE: the declaration of m_perm and the closing brace are not present in this listing, so
// whether m_perm is a copy or a reference cannot be told from here.
48template <
typename PermDerived>
49struct PermHelper<PermDerived, false> {
// The exposed permutation type is the argument type itself.
50 using type = PermDerived;
// Initializes m_perm directly from the argument.
51 PermHelper(
const PermDerived& perm) : m_perm(perm) {}
// Returns a const reference to m_perm.
52 inline const type& perm()
const {
return m_perm; }
// Specialization of permutation_matrix_product for SparseShape expressions.
// This part declares the types and compile-time switches used by permute_outer, permute_inner
// and run further down in the struct.
// NOTE: comment lines and the closing brace of the struct are not present in this listing.
57template <
typename ExpressionType,
int S
ide,
bool Transposed>
58struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape> {
// Type obtained from nested_eval for the expression, and the same type after remove_all_t.
59 using MatrixType =
typename nested_eval<ExpressionType, 1>::type;
60 using MatrixTypeCleaned = remove_all_t<MatrixType>;
// Scalar and storage-index types are taken from the cleaned matrix type.
62 using Scalar =
typename MatrixTypeCleaned::Scalar;
63 using StorageIndex =
typename MatrixTypeCleaned::StorageIndex;
// The result is a SparseMatrix with the same storage order (row/column major) as MatrixTypeCleaned.
66 using ReturnType = SparseMatrix<Scalar, MatrixTypeCleaned::IsRowMajor ? RowMajor : ColMajor, StorageIndex>;
// Helper giving access to the expression as a ReturnType (see XprHelper).
67 using TmpHelper = XprHelper<ExpressionType, ReturnType>;
// True for a row-major expression permuted on the left, or a column-major one permuted on the right.
// run() uses this flag to choose permute_outer over permute_inner.
69 static constexpr bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side ==
OnTheLeft : Side ==
OnTheRight;
// True when (Transposed and Side == OnTheLeft) or (!Transposed and Side == OnTheRight).
70 static constexpr bool NeedInversePermutation = Transposed ? Side ==
OnTheLeft : Side ==
OnTheRight;
// Reorders whole outer vectors of the sparse expression according to perm and stores the
// result in dst. Inner indices and values inside each outer vector are copied unchanged.
//   dst  - destination, receives the permuted matrix by move-assignment
//   perm - permutation; only perm.indices() is read
//   xpr  - sparse expression to permute
// NOTE: the closing braces of the loops and of the function are not present in this listing.
72 template <
typename Dest,
typename PermutationType>
73 static inline void permute_outer(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
// Access the expression as a ReturnType through TmpHelper.
77 const TmpHelper tmpHelper(xpr);
78 const ReturnType& tmp = tmpHelper.xpr();
80 ReturnType result(tmp.rows(), tmp.cols());
// Pass 1: accumulate the number of stored entries of each destination outer vector at
// outerIndexPtr()[jdst + 1].
// NOTE(review): presumably relies on the freshly constructed result having a zeroed outer index array — confirm.
82 for (Index j = 0; j < tmp.outerSize(); j++) {
83 Index jp = perm.indices().coeff(j);
// With NeedInversePermutation, source vector perm[j] goes to position j;
// otherwise source vector j goes to position perm[j].
84 Index jsrc = NeedInversePermutation ? jp : j;
85 Index jdst = NeedInversePermutation ? j : jp;
86 Index begin = tmp.outerIndexPtr()[jsrc];
// Compressed input: the vector ends where the next one starts.
// Uncompressed input: the length comes from innerNonZeroPtr().
87 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
88 result.outerIndexPtr()[jdst + 1] += end - begin;
// In-place prefix sum over outerSize() + 1 entries turns the per-vector counts into start offsets.
91 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
// Size the index/value storage to the total entry count reported by nonZeros().
92 result.resizeNonZeros(result.nonZeros());
// Pass 2: copy inner indices and values of each source vector to its destination offset.
94 for (Index j = 0; j < tmp.outerSize(); j++) {
95 Index jp = perm.indices().coeff(j);
96 Index jsrc = NeedInversePermutation ? jp : j;
97 Index jdst = NeedInversePermutation ? j : jp;
98 Index begin = tmp.outerIndexPtr()[jsrc];
99 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
100 Index target = result.outerIndexPtr()[jdst];
101 smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target);
102 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
// Hand the assembled matrix to the destination without copying it.
104 dst = std::move(result);
// Remaps the inner indices of every stored entry through the permutation and stores the result
// in dst. Outer vectors keep their position and their number of entries.
//   dst  - destination, receives the permuted matrix by move-assignment
//   perm - permutation; wrapped in PermHelper, which inverts it when NeedInversePermutation is true
//   xpr  - sparse expression to permute
// NOTE: the closing braces of the loops and of the function are not present in this listing.
107 template <
typename Dest,
typename PermutationType>
108 static inline void permute_inner(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
109 using InnerPermHelper = PermHelper<PermutationType, NeedInversePermutation>;
110 using InnerPermType =
typename InnerPermHelper::type;
// Access the expression as a ReturnType through TmpHelper.
115 const TmpHelper tmpHelper(xpr);
116 const ReturnType& tmp = tmpHelper.xpr();
// Permutation actually applied to the inner indices (inverted or as given, see PermHelper).
120 const InnerPermHelper permHelper(perm);
121 const InnerPermType& innerPerm = permHelper.perm();
123 ReturnType result(tmp.rows(), tmp.cols());
// Pass 1: accumulate the number of stored entries of outer vector j at outerIndexPtr()[j + 1].
// NOTE(review): presumably relies on the freshly constructed result having a zeroed outer index array — confirm.
125 for (Index j = 0; j < tmp.outerSize(); j++) {
126 Index begin = tmp.outerIndexPtr()[j];
// Compressed input: the vector ends where the next one starts.
// Uncompressed input: the length comes from innerNonZeroPtr().
127 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
128 result.outerIndexPtr()[j + 1] += end - begin;
// In-place prefix sum over outerSize() + 1 entries turns the per-vector counts into start offsets.
131 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
// Size the index/value storage to the total entry count reported by nonZeros().
132 result.resizeNonZeros(result.nonZeros());
// Pass 2: write each inner index mapped through innerPerm, and copy the values unchanged.
134 for (Index j = 0; j < tmp.outerSize(); j++) {
135 Index begin = tmp.outerIndexPtr()[j];
136 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
137 Index target = result.outerIndexPtr()[j];
138 std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target,
139 [&innerPerm](StorageIndex i) { return innerPerm.indices().coeff(i); });
140 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
// Sort afterwards: the remapped inner indices are presumably no longer in increasing order within a vector.
143 result.sortInnerIndices();
// Hand the assembled matrix to the destination without copying it.
144 dst = std::move(result);
// Entry point, overload enabled when DoOuter (default: NeedOuterPermutation) is true.
// Forwards all three arguments to permute_outer.
// NOTE: the closing brace of the function is not present in this listing.
147 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation,
148 std::enable_if_t<DoOuter, int> = 0>
149 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
150 permute_outer(dst, perm, xpr);
// Entry point, overload enabled when DoOuter (default: NeedOuterPermutation) is false.
// Forwards all three arguments to permute_inner.
// NOTE: the closing brace of the function is not present in this listing.
153 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation,
154 std::enable_if_t<!DoOuter, int> = 0>
155 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
156 permute_inner(dst, perm, xpr);
// product_promote_storage_type specialization for a Sparse left operand and a
// PermutationStorage right operand, for any ProductTag.
// NOTE: the body of the struct is not present in this listing.
164template <
int ProductTag>
165struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> {
// product_promote_storage_type specialization for a PermutationStorage left operand and a
// Sparse right operand, for any ProductTag.
// NOTE: the body of the struct is not present in this listing.
168template <
int ProductTag>
169struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> {
// Evaluator for an AliasFreeProduct whose left operand has PermutationShape and whose right
// operand has SparseShape. The product is evaluated into the member m_result at construction,
// and the evaluator base class is constructed over that member.
// NOTE: some lines of the struct, including closing braces, are not present in this listing.
177template <
typename Lhs,
typename Rhs,
int ProductTag>
178struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, PermutationShape, SparseShape>
179 :
public evaluator<typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType> {
180 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
// Plain sparse matrix type produced by permuting Rhs on the left.
181 typedef typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType PlainObject;
182 typedef evaluator<PlainObject> Base;
// Sizes m_result from the product expression, constructs Base over m_result in the constructor
// body (i.e. after m_result has been initialized), then evaluates the product into m_result.
186 explicit product_evaluator(
const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
187 internal::construct_at<Base>(
this, m_result);
188 generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
// Storage for the evaluated product; Base is constructed from this object.
192 PlainObject m_result;
// Evaluator for an AliasFreeProduct whose left operand has SparseShape and whose right operand
// has PermutationShape. The product is evaluated into the member m_result at construction,
// and the evaluator base class is constructed over that member.
// NOTE: some lines of the struct, including closing braces, are not present in this listing.
195template <
typename Lhs,
typename Rhs,
int ProductTag>
196struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, SparseShape, PermutationShape>
197 :
public evaluator<typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType> {
198 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
// Plain sparse matrix type produced by permuting Lhs on the right.
199 typedef typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType PlainObject;
200 typedef evaluator<PlainObject> Base;
// Sizes m_result from the product expression, constructs Base over m_result with placement-new
// in the constructor body (i.e. after m_result has been initialized), then evaluates the product
// into m_result.
// NOTE(review): the Permutation * Sparse evaluator in this file uses
// internal::construct_at<Base>(this, m_result) for the same step; consider using it here too.
204 explicit product_evaluator(
const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
205 ::new (
static_cast<Base*
>(
this)) Base(m_result);
206 generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
// Storage for the evaluated product; Base is constructed from this object.
210 PlainObject m_result;
// operator* overload returning a Product<SparseDerived, PermDerived, AliasFreeProduct> expression.
// NOTE: the parameter list and body are cut off in this listing.
217template <typename SparseDerived, typename PermDerived>
218inline const
Product<SparseDerived, PermDerived, AliasFreeProduct> operator*(
225template <
typename SparseDerived,
typename PermDerived>
233template <
typename SparseDerived,
typename PermutationType>
241template <
typename SparseDerived,
typename PermutationType>
Base class for permutations.
Definition PermutationMatrix.h:49
Expression of the product of two arbitrary matrices or vectors.
Definition Product.h:202
Base class of any sparse matrices or sparse expressions.
Definition SparseMatrixBase.h:30
@ OnTheLeft
Definition Constants.h:331
@ OnTheRight
Definition Constants.h:333
const unsigned int EvalBeforeNestingBit
Definition Constants.h:74
Namespace containing all symbols from the Eigen library.
Definition Core:137
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)
Derived & derived()
Definition EigenBase.h:49