Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
TriangularSolver.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSETRIANGULARSOLVER_H
11#define EIGEN_SPARSETRIANGULARSOLVER_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <typename Lhs, typename Rhs, int Mode,
21 int UpLo = (Mode & Lower) ? Lower
22 : (Mode & Upper) ? Upper
23 : -1,
24 int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
25struct sparse_solve_triangular_selector;
26
27// forward substitution, row-major
28template <typename Lhs, typename Rhs, int Mode>
29struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
30 typedef typename Rhs::Scalar Scalar;
31 typedef evaluator<Lhs> LhsEval;
32 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
33 static void run(const Lhs& lhs, Rhs& other) {
34 LhsEval lhsEval(lhs);
35 for (Index col = 0; col < other.cols(); ++col) {
36 for (Index i = 0; i < lhs.rows(); ++i) {
37 Scalar tmp = other.coeff(i, col);
38 Scalar lastVal(0);
39 Index lastIndex = 0;
40 for (LhsIterator it(lhsEval, i); it; ++it) {
41 lastVal = it.value();
42 lastIndex = it.index();
43 if (lastIndex == i) break;
44 tmp -= lastVal * other.coeff(lastIndex, col);
45 }
46 if (Mode & UnitDiag)
47 other.coeffRef(i, col) = tmp;
48 else {
49 eigen_assert(lastIndex == i);
50 other.coeffRef(i, col) = tmp / lastVal;
51 }
52 }
53 }
54 }
55};
56
57// backward substitution, row-major
58template <typename Lhs, typename Rhs, int Mode>
59struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
60 typedef typename Rhs::Scalar Scalar;
61 typedef evaluator<Lhs> LhsEval;
62 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
63 static void run(const Lhs& lhs, Rhs& other) {
64 LhsEval lhsEval(lhs);
65 for (Index col = 0; col < other.cols(); ++col) {
66 for (Index i = lhs.rows() - 1; i >= 0; --i) {
67 Scalar tmp = other.coeff(i, col);
68 Scalar l_ii(0);
69 LhsIterator it(lhsEval, i);
70 while (it && it.index() < i) ++it;
71 if (!(Mode & UnitDiag)) {
72 eigen_assert(it && it.index() == i);
73 l_ii = it.value();
74 ++it;
75 } else if (it && it.index() == i)
76 ++it;
77 for (; it; ++it) {
78 tmp -= it.value() * other.coeff(it.index(), col);
79 }
80
81 if (Mode & UnitDiag)
82 other.coeffRef(i, col) = tmp;
83 else
84 other.coeffRef(i, col) = tmp / l_ii;
85 }
86 }
87 }
88};
89
90// forward substitution, col-major
91template <typename Lhs, typename Rhs, int Mode>
92struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
93 typedef typename Rhs::Scalar Scalar;
94 typedef evaluator<Lhs> LhsEval;
95 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
96 static void run(const Lhs& lhs, Rhs& other) {
97 LhsEval lhsEval(lhs);
98 for (Index col = 0; col < other.cols(); ++col) {
99 for (Index i = 0; i < lhs.cols(); ++i) {
100 Scalar& tmp = other.coeffRef(i, col);
101 if (!numext::is_exactly_zero(tmp)) // optimization when other is actually sparse
102 {
103 LhsIterator it(lhsEval, i);
104 while (it && it.index() < i) ++it;
105 if (!(Mode & UnitDiag)) {
106 eigen_assert(it && it.index() == i);
107 tmp /= it.value();
108 }
109 if (it && it.index() == i) ++it;
110 for (; it; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
111 }
112 }
113 }
114 }
115};
116
117// backward substitution, col-major
118template <typename Lhs, typename Rhs, int Mode>
119struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
120 typedef typename Rhs::Scalar Scalar;
121 typedef evaluator<Lhs> LhsEval;
122 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
123 static void run(const Lhs& lhs, Rhs& other) {
124 LhsEval lhsEval(lhs);
125 for (Index col = 0; col < other.cols(); ++col) {
126 for (Index i = lhs.cols() - 1; i >= 0; --i) {
127 Scalar& tmp = other.coeffRef(i, col);
128 if (!numext::is_exactly_zero(tmp)) // optimization when other is actually sparse
129 {
130 if (!(Mode & UnitDiag)) {
131 // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
132 LhsIterator it(lhsEval, i);
133 while (it && it.index() != i) ++it;
134 eigen_assert(it && it.index() == i);
135 other.coeffRef(i, col) /= it.value();
136 }
137 LhsIterator it(lhsEval, i);
138 for (; it && it.index() < i; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
139 }
140 }
141 }
142 }
143};
144
145} // end namespace internal
146
147#ifndef EIGEN_PARSED_BY_DOXYGEN
148
149template <typename ExpressionType, unsigned int Mode>
150template <typename OtherDerived>
151void TriangularViewImpl<ExpressionType, Mode, Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const {
152 eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
153 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper | Lower)));
154
155 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
156
157 typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
158 OtherCopy;
159 OtherCopy otherCopy(other.derived());
160
161 internal::sparse_solve_triangular_selector<ExpressionType, std::remove_reference_t<OtherCopy>, Mode>::run(
162 derived().nestedExpression(), otherCopy);
163
164 if (copy) other = otherCopy;
165}
166#endif
167
168// pure sparse path
169
170namespace internal {
171
172template <typename Lhs, typename Rhs, int Mode,
173 int UpLo = (Mode & Lower) ? Lower
174 : (Mode & Upper) ? Upper
175 : -1,
176 int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
177struct sparse_solve_triangular_sparse_selector;
178
179// forward substitution, col-major
180template <typename Lhs, typename Rhs, int Mode, int UpLo>
181struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
182 typedef typename Rhs::Scalar Scalar;
183 typedef typename promote_index_type<typename traits<Lhs>::StorageIndex, typename traits<Rhs>::StorageIndex>::type
184 StorageIndex;
185 static void run(const Lhs& lhs, Rhs& other) {
186 const bool IsLower = (UpLo == Lower);
187 AmbiVector<Scalar, StorageIndex> tempVector(other.rows() * 2);
188 tempVector.setBounds(0, other.rows());
189
190 Rhs res(other.rows(), other.cols());
191 res.reserve(other.nonZeros());
192
193 for (Index col = 0; col < other.cols(); ++col) {
194 // FIXME estimate number of non zeros
195 tempVector.init(.99 /*float(other.col(col).nonZeros())/float(other.rows())*/);
196 tempVector.setZero();
197 tempVector.restart();
198 for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt) {
199 tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
200 }
201
202 for (Index i = IsLower ? 0 : lhs.cols() - 1; IsLower ? i < lhs.cols() : i >= 0; i += IsLower ? 1 : -1) {
203 tempVector.restart();
204 Scalar& ci = tempVector.coeffRef(i);
205 if (!numext::is_exactly_zero(ci)) {
206 // find
207 typename Lhs::InnerIterator it(lhs, i);
208 if (!(Mode & UnitDiag)) {
209 if (IsLower) {
210 eigen_assert(it.index() == i);
211 ci /= it.value();
212 } else
213 ci /= lhs.coeff(i, i);
214 }
215 tempVector.restart();
216 if (IsLower) {
217 if (it.index() == i) ++it;
218 for (; it; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
219 } else {
220 for (; it && it.index() < i; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
221 }
222 }
223 }
224
225 // Index count = 0;
226 // FIXME compute a reference value to filter zeros
227 for (typename AmbiVector<Scalar, StorageIndex>::Iterator it(tempVector /*,1e-12*/); it; ++it) {
228 // ++ count;
229 // std::cerr << "fill " << it.index() << ", " << col << "\n";
230 // std::cout << it.value() << " ";
231 // FIXME use insertBack
232 res.insert(it.index(), col) = it.value();
233 }
234 // std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
235 }
236 res.finalize();
237 other = res.markAsRValue();
238 }
239};
240
241} // end namespace internal
242
243#ifndef EIGEN_PARSED_BY_DOXYGEN
244template <typename ExpressionType, unsigned int Mode>
245template <typename OtherDerived>
246void TriangularViewImpl<ExpressionType, Mode, Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const {
247 eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
248 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper | Lower)));
249
250 // enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
251
252 // typedef std::conditional_t<copy,
253 // typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&> OtherCopy;
254 // OtherCopy otherCopy(other.derived());
255
256 internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(
257 derived().nestedExpression(), other.derived());
258
259 // if (copy)
260 // other = otherCopy;
261}
262#endif
263
264} // end namespace Eigen
265
266#endif // EIGEN_SPARSETRIANGULARSOLVER_H
@ UnitDiag
Definition Constants.h:215
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition Core:137