Eigen  3.4.90 (git rev 5a9f66fb35d03a4da9ef8976e67a61b30aa16dcf)
 
Loading...
Searching...
No Matches
ConjugateGradient.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2011-2014 Gael Guennebaud <[email protected]>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CONJUGATE_GRADIENT_H
11#define EIGEN_CONJUGATE_GRADIENT_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
29template <typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
30EIGEN_DONT_INLINE void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, const Preconditioner& precond,
31 Index& iters, typename Dest::RealScalar& tol_error) {
32 typedef typename Dest::RealScalar RealScalar;
33 typedef typename Dest::Scalar Scalar;
34 typedef Matrix<Scalar, Dynamic, 1> VectorType;
35
36 RealScalar tol = tol_error;
37 Index maxIters = iters;
38
39 Index n = mat.cols();
40
41 VectorType residual = rhs - mat * x; // initial residual
42
43 RealScalar rhsNorm2 = rhs.squaredNorm();
44 if (rhsNorm2 == 0) {
45 x.setZero();
46 iters = 0;
47 tol_error = 0;
48 return;
49 }
50 const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
51 RealScalar threshold = numext::maxi(RealScalar(tol * tol * rhsNorm2), considerAsZero);
52 RealScalar residualNorm2 = residual.squaredNorm();
53 if (residualNorm2 < threshold) {
54 iters = 0;
55 tol_error = numext::sqrt(residualNorm2 / rhsNorm2);
56 return;
57 }
58
59 VectorType p(n);
60 p = precond.solve(residual); // initial search direction
61
62 VectorType z(n), tmp(n);
63 RealScalar absNew = numext::real(residual.dot(p)); // the square of the absolute value of r scaled by invM
64 Index i = 0;
65 while (i < maxIters) {
66 tmp.noalias() = mat * p; // the bottleneck of the algorithm
67
68 Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir
69 x += alpha * p; // update solution
70 residual -= alpha * tmp; // update residual
71
72 residualNorm2 = residual.squaredNorm();
73 if (residualNorm2 < threshold) break;
74
75 z = precond.solve(residual); // approximately solve for "A z = residual"
76
77 RealScalar absOld = absNew;
78 absNew = numext::real(residual.dot(z)); // update the absolute value of r
79 RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction
80 p = z + beta * p; // update search direction
81 i++;
82 }
83 tol_error = numext::sqrt(residualNorm2 / rhsNorm2);
84 iters = i;
85}
86
87} // namespace internal
88
89template <typename MatrixType_, int UpLo_ = Lower,
90 typename Preconditioner_ = DiagonalPreconditioner<typename MatrixType_::Scalar> >
91class ConjugateGradient;
92
93namespace internal {
94
95template <typename MatrixType_, int UpLo_, typename Preconditioner_>
96struct traits<ConjugateGradient<MatrixType_, UpLo_, Preconditioner_> > {
97 typedef MatrixType_ MatrixType;
98 typedef Preconditioner_ Preconditioner;
99};
100
101} // namespace internal
102
151template <typename MatrixType_, int UpLo_, typename Preconditioner_>
152class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<MatrixType_, UpLo_, Preconditioner_> > {
154 using Base::m_error;
155 using Base::m_info;
156 using Base::m_isInitialized;
157 using Base::m_iterations;
158 using Base::matrix;
159
160 public:
161 typedef MatrixType_ MatrixType;
162 typedef typename MatrixType::Scalar Scalar;
163 typedef typename MatrixType::RealScalar RealScalar;
164 typedef Preconditioner_ Preconditioner;
165
166 enum { UpLo = UpLo_ };
167
168 public:
171
182 template <typename MatrixDerived>
183 explicit ConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {}
184
186
188 template <typename Rhs, typename Dest>
189 void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const {
190 typedef typename Base::MatrixWrapper MatrixWrapper;
191 typedef typename Base::ActualMatrixType ActualMatrixType;
192 enum {
193 TransposeInput = (!MatrixWrapper::MatrixFree) && (UpLo == (Lower | Upper)) && (!MatrixType::IsRowMajor) &&
194 (!NumTraits<Scalar>::IsComplex)
195 };
196 typedef std::conditional_t<TransposeInput, Transpose<const ActualMatrixType>, ActualMatrixType const&>
197 RowMajorWrapper;
198 EIGEN_STATIC_ASSERT(internal::check_implication(MatrixWrapper::MatrixFree, UpLo == (Lower | Upper)),
199 MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY);
200 typedef std::conditional_t<UpLo == (Lower | Upper), RowMajorWrapper,
201 typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type>
202 SelfAdjointWrapper;
203
204 m_iterations = Base::maxIterations();
205 m_error = Base::m_tolerance;
206
207 RowMajorWrapper row_mat(matrix());
208 internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error);
209 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
210 }
211
212 protected:
213};
214
215} // end namespace Eigen
216
217#endif // EIGEN_CONJUGATE_GRADIENT_H
A conjugate gradient solver for sparse (or dense) self-adjoint problems.
Definition ConjugateGradient.h:152
ConjugateGradient(const EigenBase< MatrixDerived > &A)
Definition ConjugateGradient.h:183
ConjugateGradient()
Definition ConjugateGradient.h:170
Base class for linear iterative solvers.
Definition IterativeSolverBase.h:124
Index maxIterations() const
Definition IterativeSolverBase.h:251
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ Success
Definition Constants.h:440
@ NoConvergence
Definition Constants.h:444
Namespace containing all symbols from the Eigen library.
Definition Core:137
Definition EigenBase.h:33