10#ifndef EIGEN_SOLVETRIANGULAR_H
11#define EIGEN_SOLVETRIANGULAR_H
14#include "./InternalHeaderCheck.h"
// Forward declaration of the triangular-solve kernel used when the right-hand side
// is a single vector. In this file it is named by
// triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1>, which passes the
// LHS/RHS scalar types, the index type, the solve side (OnTheLeft / OnTheRight),
// the triangular Mode flags and LhsProductTraits::NeedToConjugate as `Conjugate`.
// NOTE(review): the argument bound to `StorageOrder` is truncated at the visible call
// site; presumably it is the LHS storage order — confirm against the full header.
// The definition is provided elsewhere (not visible here).
template <typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate,
          int StorageOrder>
struct triangular_solve_vector;
// Forward declaration of the blocked triangular-solve kernel used when the right-hand
// side has several columns/rows. In this file it is instantiated by
// triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, Dynamic>. That caller
// invokes ::run(size, othersize, lhs data, lhs outer stride, rhs data, rhs inner
// stride, rhs outer stride, blocking) and binds `OtherInnerStride` to
// Rhs::InnerStrideAtCompileTime.
// NOTE(review): the arguments bound to TriStorageOrder / OtherStorageOrder are
// truncated at the visible call site — confirm against the full header.
// The definition is provided elsewhere (not visible here).
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder,
          int OtherStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix;
30template <
typename Lhs,
typename Rhs,
int S
ide>
33 enum { RhsIsVectorAtCompileTime = (Side ==
OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime) == 1 };
37 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime !=
Dynamic && Rhs::SizeAtCompileTime <= 8)
40 RhsVectors = RhsIsVectorAtCompileTime ? 1 :
Dynamic
44template <
typename Lhs,
typename Rhs,
47 int Unrolling = trsolve_traits<Lhs, Rhs, Side>::Unrolling,
48 int RhsVectors = trsolve_traits<Lhs, Rhs, Side>::RhsVectors>
49struct triangular_solver_selector;
51template <
typename Lhs,
typename Rhs,
int S
ide,
int Mode>
52struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling, 1> {
53 typedef typename Lhs::Scalar LhsScalar;
54 typedef typename Rhs::Scalar RhsScalar;
55 typedef blas_traits<Lhs> LhsProductTraits;
56 typedef typename LhsProductTraits::ExtractType ActualLhsType;
57 typedef Map<Matrix<RhsScalar, Dynamic, 1>,
Aligned> MappedRhs;
58 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs, Rhs& rhs) {
59 ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
63 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime == 1 || rhs.innerStride() == 1;
65 ei_declare_aligned_stack_constructed_variable(RhsScalar, actualRhs, rhs.size(), (useRhsDirectly ? rhs.data() : 0));
67 if (!useRhsDirectly) MappedRhs(actualRhs, rhs.size()) = rhs;
69 triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
72 actualLhs.outerStride(),
75 if (!useRhsDirectly) rhs = MappedRhs(actualRhs, rhs.size());
80template <
typename Lhs,
typename Rhs,
int S
ide,
int Mode>
81struct triangular_solver_selector<Lhs, Rhs, Side, Mode, NoUnrolling,
Dynamic> {
82 typedef typename Rhs::Scalar Scalar;
83 typedef blas_traits<Lhs> LhsProductTraits;
84 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
86 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs, Rhs& rhs) {
87 add_const_on_value_type_t<ActualLhsType> actualLhs = LhsProductTraits::extract(lhs);
89 const Index size = lhs.rows();
90 const Index othersize = Side ==
OnTheLeft ? rhs.cols() : rhs.rows();
92 typedef internal::gemm_blocking_space<(Rhs::Flags &
RowMajorBit) ? RowMajor :
ColMajor, Scalar, Scalar,
93 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime,
94 Lhs::MaxRowsAtCompileTime, 4>
98 if (actualLhs.size() == 0 || rhs.size() == 0) {
102 BlockingType blocking(rhs.rows(), rhs.cols(), size, 1,
false);
104 triangular_solve_matrix<Scalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate,
107 Rhs::InnerStrideAtCompileTime>::run(size, othersize, &actualLhs.coeffRef(0, 0),
108 actualLhs.outerStride(), &rhs.coeffRef(0, 0),
109 rhs.innerStride(), rhs.outerStride(), blocking);
// Compile-time unrolled forward/back substitution, one step per LoopIndex.
// `Stop` defaults to (LoopIndex == Size). The `false` specialization performs one
// substitution step and recurses with LoopIndex + 1. The `true` specialization has an
// empty run() and ends the recursion.
template <typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size, bool Stop = LoopIndex == Size>
struct triangular_solver_unroller;
120template <
typename Lhs,
typename Rhs,
int Mode,
int LoopIndex,
int Size>
121struct triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex, Size, false> {
123 IsLower = ((Mode &
Lower) == Lower),
124 DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1,
125 StartIndex = IsLower ? 0 : DiagIndex + 1
127 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs, Rhs& rhs) {
129 rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex)
130 .template segment<LoopIndex>(StartIndex)
132 .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex))
135 if (!(Mode & UnitDiag)) rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex, DiagIndex);
137 triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex + 1, Size>::run(lhs, rhs);
141template <
typename Lhs,
typename Rhs,
int Mode,
int LoopIndex,
int Size>
142struct triangular_solver_unroller<Lhs, Rhs, Mode, LoopIndex, Size, true> {
143 static EIGEN_DEVICE_FUNC
void run(
const Lhs&, Rhs&) {}
146template <
typename Lhs,
typename Rhs,
int Mode>
147struct triangular_solver_selector<Lhs, Rhs,
OnTheLeft, Mode, CompleteUnrolling, 1> {
148 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs, Rhs& rhs) {
149 triangular_solver_unroller<Lhs, Rhs, Mode, 0, Rhs::SizeAtCompileTime>::run(lhs, rhs);
153template <
typename Lhs,
typename Rhs,
int Mode>
154struct triangular_solver_selector<Lhs, Rhs,
OnTheRight, Mode, CompleteUnrolling, 1> {
155 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs, Rhs& rhs) {
156 Transpose<const Lhs> trLhs(lhs);
157 Transpose<Rhs> trRhs(rhs);
159 triangular_solver_unroller<Transpose<const Lhs>, Transpose<Rhs>,
161 Rhs::SizeAtCompileTime>::run(trLhs, trRhs);
171#ifndef EIGEN_PARSED_BY_DOXYGEN
172template <
typename MatrixType,
unsigned int Mode>
173template <
int S
ide,
typename OtherDerived>
174EIGEN_DEVICE_FUNC
void TriangularViewImpl<MatrixType, Mode, Dense>::solveInPlace(
175 const MatrixBase<OtherDerived>& _other)
const {
176 OtherDerived& other = _other.const_cast_derived();
177 eigen_assert(derived().cols() == derived().rows() && ((Side ==
OnTheLeft && derived().cols() == other.rows()) ||
178 (Side ==
OnTheRight && derived().cols() == other.cols())));
179 eigen_assert((!(
int(Mode) &
int(
ZeroDiag))) &&
bool(
int(Mode) & (
int(
Upper) |
int(
Lower))));
181 if (derived().cols() == 0)
return;
184 copy = (internal::traits<OtherDerived>::Flags &
RowMajorBit) && OtherDerived::IsVectorAtCompileTime &&
185 OtherDerived::SizeAtCompileTime != 1
187 typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
189 OtherCopy otherCopy(other);
191 internal::triangular_solver_selector<MatrixType, std::remove_reference_t<OtherCopy>, Side, Mode>::run(
192 derived().nestedExpression(), otherCopy);
194 if (copy) other = otherCopy;
197template <
typename Derived,
unsigned int Mode>
198template <
int S
ide,
typename Other>
199const internal::triangular_solve_retval<Side, TriangularView<Derived, Mode>, Other>
200TriangularViewImpl<Derived, Mode, Dense>::solve(
const MatrixBase<Other>& other)
const {
201 return internal::triangular_solve_retval<Side, TriangularViewType, Other>(derived(), other.derived());
207template <
int S
ide,
typename TriangularType,
typename Rhs>
208struct traits<triangular_solve_retval<Side, TriangularType, Rhs> > {
209 typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType;
212template <
int S
ide,
typename TriangularType,
typename Rhs>
213struct triangular_solve_retval :
public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> > {
214 typedef remove_all_t<typename Rhs::Nested> RhsNestedCleaned;
215 typedef ReturnByValue<triangular_solve_retval> Base;
217 triangular_solve_retval(
const TriangularType& tri,
const Rhs& rhs) : m_triangularMatrix(tri), m_rhs(rhs) {}
219 inline EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT {
return m_rhs.rows(); }
220 inline EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT {
return m_rhs.cols(); }
222 template <
typename Dest>
223 inline void evalTo(Dest& dst)
const {
224 if (!is_same_dense(dst, m_rhs)) dst = m_rhs;
225 m_triangularMatrix.template solveInPlace<Side>(dst);
229 const TriangularType& m_triangularMatrix;
230 typename Rhs::Nested m_rhs;
// Cross-reference notes from the generated documentation (symbols used above):
//   UnitDiag    - enumerator, defined in Constants.h:215
//   ZeroDiag    - enumerator, defined in Constants.h:217
//   Lower       - enumerator, defined in Constants.h:211
//   Upper       - enumerator, defined in Constants.h:213
//   Aligned     - enumerator, defined in Constants.h:242
//   ColMajor    - enumerator, defined in Constants.h:318
//   RowMajor    - enumerator, defined in Constants.h:320
//   OnTheLeft   - enumerator, defined in Constants.h:331
//   OnTheRight  - enumerator, defined in Constants.h:333
//   RowMajorBit - const unsigned int, defined in Constants.h:70
//   Eigen       - namespace containing all symbols from the Eigen library, defined in Core:137
//   Index       - EIGEN_DEFAULT_DENSE_INDEX_TYPE; the Index type as used for the API, defined in Meta.h:83
//   Dynamic     - const int, defined in Constants.h:25