Eigen 3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
SolveTriangular.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2009 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SOLVETRIANGULAR_H
11#define EIGEN_SOLVETRIANGULAR_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Forward declarations of the low-level solve kernels used by the selectors in this file.
// Their definitions live in the products/TriangularSolver*.h headers.

// Kernel invoked when the right-hand side is a single vector.
template <class LhsScalar, class RhsScalar, class Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;

// Kernel invoked when the right-hand side is a matrix (possibly with a non-unit inner stride).
template <class Scalar, class Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder,
          int OtherInnerStride>
struct triangular_solve_matrix;
28
29// small helper struct extracting some traits on the underlying solver operation
30template <typename Lhs, typename Rhs, int Side>
31class trsolve_traits {
32 private:
33 enum { RhsIsVectorAtCompileTime = (Side == OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime) == 1 };
34
35 public:
36 enum {
37 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8)
38 ? CompleteUnrolling
39 : NoUnrolling,
40 RhsVectors = RhsIsVectorAtCompileTime ? 1 : Dynamic
41 };
42};
43
44template <typename Lhs, typename Rhs,
45 int Side, // can be OnTheLeft/OnTheRight
46 int Mode, // can be Upper/Lower | UnitDiag
47 int Unrolling = trsolve_traits<Lhs, Rhs, Side>::Unrolling,
48 int RhsVectors = trsolve_traits<Lhs, Rhs, Side>::RhsVectors>
49struct triangular_solver_selector;
50
51template <typename Lhs, typename Rhs, int Side, int Mode>
52struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
53 typedef typename Lhs::Scalar LhsScalar;
54 typedef typename Rhs::Scalar RhsScalar;
55 typedef blas_traits<Lhs> LhsProductTraits;
56 typedef typename LhsProductTraits::ExtractType ActualLhsType;
57 typedef Map<Matrix<RhsScalar, Dynamic, 1>, Aligned> MappedRhs;
58 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
59 ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
60
61 // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1
62
63 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime == 1 || rhs.innerStride() == 1;
64
65 ei_declare_aligned_stack_constructed_variable(RhsScalar, actualRhs, rhs.size(), (useRhsDirectly ? rhs.data() : 0));
66
67 if (!useRhsDirectly) MappedRhs(actualRhs, rhs.size()) = rhs;
68
69 triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
70 (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>::run(actualLhs.cols(),
71 actualLhs.data(),
72 actualLhs.outerStride(),
73 actualRhs);
74
75 if (!useRhsDirectly) rhs = MappedRhs(actualRhs, rhs.size());
76 }
77};
78
79// the rhs is a matrix
80template <typename Lhs, typename Rhs, int Side, int Mode>
81struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, Dynamic> {
82 typedef typename Rhs::Scalar Scalar;
83 typedef blas_traits<Lhs> LhsProductTraits;
84 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
85
86 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
87 add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
88
89 const Index size = lhs.rows();
90 const Index othersize = Side == OnTheLeft ? rhs.cols() : rhs.rows();
91
92 typedef internal::gemm_blocking_space<(Rhs::Flags & RowMajorBit) ? RowMajor : ColMajor, Scalar, Scalar,
93 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime,
94 Lhs::MaxRowsAtCompileTime, 4>
95 BlockingType;
96
97 // Nothing to solve.
98 if (actualLhs.size() == 0 || rhs.size() == 0) {
99 return;
100 }
101
102 BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false);
103
104 triangular_solve_matrix<Scalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
105 (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
106 (Rhs::Flags & RowMajorBit) ? RowMajor : ColMajor,
107 Rhs::InnerStrideAtCompileTime>::run(size, othersize, &actualLhs.coeffRef(0, 0),
108 actualLhs.outerStride(), &rhs.coeffRef(0, 0),
109 rhs.innerStride(), rhs.outerStride(), blocking);
110 }
111};
112
113/***************************************************************************
114 * meta-unrolling implementation
115 ***************************************************************************/
116
117template <typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size, bool Stop = LoopIndex == Size>
118struct triangular_solver_unroller;
119
120template <typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
121struct triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex, Size, false> {
122 enum {
123 IsLower = ((Mode & Lower) == Lower),
124 DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1,
125 StartIndex = IsLower ? 0 : DiagIndex + 1
126 };
127 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
128 if (LoopIndex > 0)
129 rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex)
130 .template segment<LoopIndex>(StartIndex)
131 .transpose()
132 .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex))
133 .sum();
134
135 if (!(Mode & UnitDiag)) rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex, DiagIndex);
136
137 triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex + 1, Size>::run(lhs, rhs);
138 }
139};
140
141template <typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size>
142struct triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex, Size, true> {
143 static EIGEN_DEVICE_FUNC void run(const Lhs&, Rhs&) {}
144};
145
146template <typename Lhs, typename Rhs, int Mode>
147struct triangular_solver_selector<Lhs, Rhs, OnTheLeft, Mode, CompleteUnrolling, 1> {
148 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
149 triangular_solver_unroller<Lhs, Rhs, Mode, 0, Rhs::SizeAtCompileTime>::run(lhs, rhs);
150 }
151};
152
153template <typename Lhs, typename Rhs, int Mode>
154struct triangular_solver_selector<Lhs, Rhs, OnTheRight, Mode, CompleteUnrolling, 1> {
155 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, Rhs& rhs) {
156 Transpose<const Lhs> trLhs(lhs);
157 Transpose<Rhs> trRhs(rhs);
158
159 triangular_solver_unroller<Transpose<const Lhs>, Transpose<Rhs>,
160 ((Mode & Upper) == Upper ? Lower : Upper) | (Mode & UnitDiag), 0,
161 Rhs::SizeAtCompileTime>::run(trLhs, trRhs);
162 }
163};
164
165} // end namespace internal
166
167/***************************************************************************
168 * TriangularView methods
169 ***************************************************************************/
170
171#ifndef EIGEN_PARSED_BY_DOXYGEN
172template <typename MatrixType, unsigned int Mode>
173template <int Side, typename OtherDerived>
174EIGEN_DEVICE_FUNC void TriangularViewImpl<MatrixType, Mode, Dense>::solveInPlace(
175 const MatrixBase<OtherDerived>& _other) const {
176 OtherDerived& other = _other.const_cast_derived();
177 eigen_assert(derived().cols() == derived().rows() && ((Side == OnTheLeft && derived().cols() == other.rows()) ||
178 (Side == OnTheRight && derived().cols() == other.cols())));
179 eigen_assert((!(int(Mode) & int(ZeroDiag))) && bool(int(Mode) & (int(Upper) | int(Lower))));
180 // If solving for a 0x0 matrix, nothing to do, simply return.
181 if (derived().cols() == 0) return;
182
183 enum {
184 copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime &&
185 OtherDerived::SizeAtCompileTime != 1
186 };
187 typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
188 OtherCopy;
189 OtherCopy otherCopy(other);
190
191 internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>, Side, Mode>::run(
192 derived().nestedExpression(), otherCopy);
193
194 if (copy) other = otherCopy;
195}
196
197template <typename Derived, unsigned int Mode>
198template <int Side, typename Other>
199const internal::triangular_solve_retval<Side, TriangularView<Derived, Mode>, Other>
200TriangularViewImpl<Derived, Mode, Dense>::solve(const MatrixBase<Other>& other) const {
201 return internal::triangular_solve_retval<Side, TriangularViewType, Other>(derived(), other.derived());
202}
203#endif
204
205namespace internal {
206
207template <int Side, typename TriangularType, typename Rhs>
208struct traits<triangular_solve_retval<Side, TriangularType, Rhs> > {
209 typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType;
210};
211
212template <int Side, typename TriangularType, typename Rhs>
213struct triangular_solve_retval : public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> > {
214 typedef remove_all_t<typename Rhs::Nested> RhsNestedCleaned;
215 typedef ReturnByValue<triangular_solve_retval> Base;
216
217 triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) : m_triangularMatrix(tri), m_rhs(rhs) {}
218
219 inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_rhs.rows(); }
220 inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }
221
222 template <typename Dest>
223 inline void evalTo(Dest& dst) const {
224 if (!is_same_dense(dst, m_rhs)) dst = m_rhs;
225 m_triangularMatrix.template solveInPlace<Side>(dst);
226 }
227
228 protected:
229 const TriangularType& m_triangularMatrix;
230 typename Rhs::Nested m_rhs;
231};
232
233} // namespace internal
234
235} // end namespace Eigen
236
237#endif // EIGEN_SOLVETRIANGULAR_H
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ Aligned
Definition Constants.h:242
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
@ OnTheLeft
Definition Constants.h:331
@ OnTheRight
Definition Constants.h:333
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition Core:137
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:83
const int Dynamic
Definition Constants.h:25