10#ifndef EIGEN_PASTIXSUPPORT_H
11#define EIGEN_PASTIXSUPPORT_H
14#include "./InternalHeaderCheck.h"
// Element types used for the complex PaStiX entry points: the eigen_pastix
// overloads reinterpret_cast their std::complex<float> / std::complex<double>
// buffers to PASTIX_COMPLEX / PASTIX_DCOMPLEX before calling c_pastix / z_pastix.
// NOTE(review): as shown, both pairs below are defined unconditionally, which
// redefines each macro with a different body. The gaps in the embedded line
// numbers (18, 21, 24) suggest a selecting #if / #else / #endif is hidden from
// this view -- confirm against the full header.
// First pair: presumably the complex type names provided by the PaStiX headers.
19#define PASTIX_COMPLEX COMPLEX
20#define PASTIX_DCOMPLEX DCOMPLEX
// Second pair: the standard library complex types.
22#define PASTIX_COMPLEX std::complex<float>
23#define PASTIX_DCOMPLEX std::complex<double>
// Forward declarations of the PaStiX solver front-ends and of the traits class
// that PastixBase queries (internal::pastix_traits<Derived>::MatrixType).
// NOTE(review): the declared entities themselves (embedded lines 35, 37, 39 and
// 44) are hidden in this view. Judging from the later definitions they are
// presumably PastixLU (the later PastixLU definition takes `bool IsStrSym`
// without a default, so the default must come from here), PastixLLT,
// PastixLDLT and pastix_traits -- confirm against the full header.
34template <
typename MatrixType_,
bool IsStrSym = false>
36template <
typename MatrixType_,
int Options>
38template <
typename MatrixType_,
int Options>
43template <
class Pastix>
46template <
typename MatrixType_>
47struct pastix_traits<PastixLU<MatrixType_> > {
48 typedef MatrixType_ MatrixType;
49 typedef typename MatrixType_::Scalar Scalar;
50 typedef typename MatrixType_::RealScalar RealScalar;
51 typedef typename MatrixType_::StorageIndex StorageIndex;
54template <
typename MatrixType_,
int Options>
55struct pastix_traits<PastixLLT<MatrixType_, Options> > {
56 typedef MatrixType_ MatrixType;
57 typedef typename MatrixType_::Scalar Scalar;
58 typedef typename MatrixType_::RealScalar RealScalar;
59 typedef typename MatrixType_::StorageIndex StorageIndex;
62template <
typename MatrixType_,
int Options>
63struct pastix_traits<PastixLDLT<MatrixType_, Options> > {
64 typedef MatrixType_ MatrixType;
65 typedef typename MatrixType_::Scalar Scalar;
66 typedef typename MatrixType_::RealScalar RealScalar;
67 typedef typename MatrixType_::StorageIndex StorageIndex;
/** \internal
 * Single-precision real (float) overload of the PaStiX dispatcher.
 *
 * The four eigen_pastix overloads share one parameter list and differ only in
 * the element type of `vals` and `x`. Callers pass `(Scalar *)0` or typed
 * buffers, so overload resolution picks the PaStiX routine of matching
 * precision. This overload ends by calling s_pastix with the same argument list.
 *
 * NOTE(review): the embedded line numbers jump from 71 to 81, so any statements
 * between the signature and the call (and the closing brace) are hidden here.
 * Visible callers pass n == 0 with null ptr/idx/vals (init, the cleanup method)
 * and nbrhs == 0 with a null x (init, cleanup, analyzePattern, factorize).
 * Confirm how the hidden lines treat those values before they reach PaStiX.
 */
70inline void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
float *vals,
71 int *perm,
int *invp,
float *x,
int nbrhs,
int *iparm,
double *dparm) {
81 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
/** \internal
 * Double-precision real (double) overload of the PaStiX dispatcher.
 * It has the same parameter list as the other overloads and ends by calling
 * d_pastix with the same argument list.
 *
 * NOTE(review): the embedded line numbers jump from 85 to 95, so statements
 * between the signature and the call (and the closing brace) are hidden here.
 */
84inline void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
double *vals,
85 int *perm,
int *invp,
double *x,
int nbrhs,
int *iparm,
double *dparm) {
95 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
/** \internal
 * Single-precision complex (std::complex<float>) overload of the PaStiX
 * dispatcher. It ends by calling c_pastix and passes vals / x through
 * reinterpret_cast to PASTIX_COMPLEX*. The C++ standard (C++11 and later)
 * guarantees std::complex<float> has the same layout as float[2], which is
 * what such a cast between complex representations relies on.
 */
98inline void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
99 std::complex<float> *vals,
int *perm,
int *invp, std::complex<float> *x,
int nbrhs,
int *iparm,
// NOTE(review): the parameter list shown here stops at `int *iparm,`, yet the
// call below uses `dparm`. The sibling overloads end with `double *dparm) {`;
// that line (embedded line 100) and lines up to 109 are hidden from this view.
110 c_pastix(pastix_data, pastix_comm, n, ptr, idx,
reinterpret_cast<PASTIX_COMPLEX *
>(vals), perm, invp,
111 reinterpret_cast<PASTIX_COMPLEX *
>(x), nbrhs, iparm, dparm);
/** \internal
 * Double-precision complex (std::complex<double>) overload of the PaStiX
 * dispatcher. It ends by calling z_pastix and passes vals / x through
 * reinterpret_cast to PASTIX_DCOMPLEX*. The C++ standard (C++11 and later)
 * guarantees std::complex<double> has the same layout as double[2].
 *
 * NOTE(review): embedded lines 117-125 (between the signature and the call)
 * and the closing brace are hidden here.
 */
114inline void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
115 std::complex<double> *vals,
int *perm,
int *invp, std::complex<double> *x,
int nbrhs,
116 int *iparm,
double *dparm) {
126 z_pastix(pastix_data, pastix_comm, n, ptr, idx,
reinterpret_cast<PASTIX_DCOMPLEX *
>(vals), perm, invp,
127 reinterpret_cast<PASTIX_DCOMPLEX *
>(x), nbrhs, iparm, dparm);
/** \internal
 * Converts a compressed sparse matrix from C (0-based) to Fortran (1-based)
 * index numbering, in place.
 *
 * The PaStiX front-ends call this from grabMatrix, so the index arrays later
 * handed to PaStiX by analyzePattern / factorize are 1-based.
 *
 * The conversion only runs while the first outer index is 0, i.e. while the
 * matrix is still in C numbering. A second call is therefore a no-op.
 *
 * The outer index array holds outerSize() + 1 entries: one start offset per
 * outer vector plus the end sentinel. The loop is bounded by outerSize() rather
 * than rows(). The two coincide for square matrices, but for a non-square
 * column-major matrix rows() would convert too few entries, or run past the
 * end of the array.
 *
 * \param mat sparse matrix in compressed storage (assumed: its inner indices
 *            occupy the first nonZeros() slots); modified in place.
 */
template <typename MatrixType>
void c_to_fortran_numbering(MatrixType &mat) {
  if (!(mat.outerIndexPtr()[0])) {
    for (decltype(mat.outerSize()) j = 0; j <= mat.outerSize(); ++j) ++mat.outerIndexPtr()[j];
    for (decltype(mat.nonZeros()) k = 0; k < mat.nonZeros(); ++k) ++mat.innerIndexPtr()[k];
  }
}
/** \internal
 * Converts a compressed sparse matrix from Fortran (1-based) back to C
 * (0-based) index numbering, in place. This is the inverse of
 * c_to_fortran_numbering.
 *
 * The conversion only runs when the first outer index is 1, i.e. when the
 * matrix is in Fortran numbering. A second call is therefore a no-op.
 *
 * As in c_to_fortran_numbering, the outer index array holds outerSize() + 1
 * entries. Bounding the loop by outerSize() instead of rows() keeps it correct
 * for non-square column-major matrices.
 *
 * \param mat sparse matrix in compressed storage (assumed: its inner indices
 *            occupy the first nonZeros() slots); modified in place.
 */
template <typename MatrixType>
void fortran_to_c_numbering(MatrixType &mat) {
  if (mat.outerIndexPtr()[0] == 1) {
    for (decltype(mat.outerSize()) j = 0; j <= mat.outerSize(); ++j) --mat.outerIndexPtr()[j];
    for (decltype(mat.nonZeros()) k = 0; k < mat.nonZeros(); ++k) --mat.innerIndexPtr()[k];
  }
}
/** \internal
 * \brief CRTP base shared by the PaStiX front-ends PastixLU, PastixLLT and PastixLDLT.
 *
 * Derived is the concrete front-end; the matrix type comes from
 * internal::pastix_traits<Derived>. The class holds:
 *  - the PaStiX handle (m_pastixdata);
 *  - the integer and double parameter arrays passed to every PaStiX call
 *    (m_iparm / m_dparm);
 *  - the permutation buffers sized during analysis (m_perm / m_invp).
 *
 * NOTE(review): this view of the class is incomplete. The embedded line
 * numbers skip:
 *  - access specifiers;
 *  - the declarations of members the constructor initializes (m_initisOk,
 *    m_analysisIsOk, m_size);
 *  - the constructor body;
 *  - the header of the method whose body issues API_TASK_CLEAN below
 *    (presumably clean(), which the destructor calls);
 *  - the remaining arguments of that call.
 */
154template <
class Derived>
155class PastixBase :
public SparseSolverBase<Derived> {
// Members of SparseSolverBase<Derived> used here (the Base alias they refer to
// is presumably declared on a hidden line).
159 using Base::m_isInitialized;
162 using Base::_solve_impl;
// Types derived from the user's matrix type through pastix_traits.
164 typedef typename internal::pastix_traits<Derived>::MatrixType MatrixType_;
165 typedef MatrixType_ MatrixType;
166 typedef typename MatrixType::Scalar Scalar;
167 typedef typename MatrixType::RealScalar RealScalar;
168 typedef typename MatrixType::StorageIndex StorageIndex;
169 typedef Matrix<Scalar, Dynamic, 1> Vector;
// Format handed to PaStiX: column-major sparse storage of the user's scalar type.
170 typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
171 enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };
// Starts with all status flags false, no PaStiX handle and size 0.
174 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_pastixdata(0), m_size(0) {
// The destructor calls clean().
178 ~PastixBase() { clean(); }
// Column-by-column solve; defined out of line.
180 template <
typename Rhs,
typename Dest>
181 bool _solve_impl(
const MatrixBase<Rhs> &b, MatrixBase<Dest> &x)
const;
// Reference to the whole integer parameter array.
// NOTE(review): the return type is Array<StorageIndex, IPARM_SIZE, 1>& while
// m_iparm is declared as Array<int, IPARM_SIZE, 1>, so this only compiles when
// StorageIndex is int. The element accessor below returns int&, which suggests
// Array<int, IPARM_SIZE, 1>& is the intended return type.
188 Array<StorageIndex, IPARM_SIZE, 1> &iparm() {
return m_iparm; }
// One integer parameter, selected by idxparam (an IPARM_* position).
194 int &iparm(
int idxparam) {
return m_iparm(idxparam); }
// Reference to the whole double parameter array.
200 Array<double, DPARM_SIZE, 1> &dparm() {
return m_dparm; }
// One double parameter, selected by idxparam (a DPARM_* position).
205 double &dparm(
int idxparam) {
return m_dparm(idxparam); }
// Both dimensions report m_size (the system is square).
207 inline Index cols()
const {
return m_size; }
208 inline Index rows()
const {
return m_size; }
// Status of the last operation; asserts the decomposition was initialized
// (the return statement is on a hidden line).
218 ComputationInfo info()
const {
219 eigen_assert(m_isInitialized &&
"Decomposition is not initialized.");
// Symbolic and numeric phases on a ColSpMatrix; defined out of line.
228 void analyzePattern(ColSpMatrix &mat);
231 void factorize(ColSpMatrix &mat);
// Body fragment of the cleanup method (header hidden). It requires a successful
// init, then asks PaStiX to run only the API_TASK_CLEAN step, with n == 0 and
// no matrix arrays.
235 eigen_assert(m_initisOk &&
"The Pastix structure should be allocated first");
236 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
237 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
238 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar *)0, m_perm.
data(), m_invp.
data(), 0, 0,
// Whole-pipeline entry point on a ColSpMatrix; defined out of line.
242 void compute(ColSpMatrix &mat);
// Status flags and solver state. Several members are mutable because the const
// _solve_impl writes m_iparm and hands these buffers to PaStiX.
// NOTE(review): m_perm / m_invp hold StorageIndex but are passed where
// eigen_pastix expects int*, which again assumes StorageIndex is int.
246 int m_factorizationIsOk;
247 mutable ComputationInfo m_info;
248 mutable pastix_data_t *m_pastixdata;
250 mutable Array<int, IPARM_SIZE, 1> m_iparm;
251 mutable Array<double, DPARM_SIZE, 1> m_dparm;
252 mutable Matrix<StorageIndex, Dynamic, 1> m_perm;
253 mutable Matrix<StorageIndex, Dynamic, 1> m_invp;
/** \internal
 * One-time PaStiX setup for this solver object.
 *
 * Visible steps:
 *  1. Zero m_iparm and m_dparm.
 *  2. Set IPARM_MODIFY_PARAMETER = API_NO and call pastix() directly
 *     (presumably so PaStiX writes its default parameters into
 *     m_iparm / m_dparm -- confirm against the PaStiX manual).
 *  3. Set a handful of parameters explicitly.
 *  4. Run only the API_TASK_INIT step through internal::eigen_pastix and
 *     inspect IPARM_ERROR_NUMBER.
 *
 * NOTE(review): embedded lines 263, 266, 269, 277, 281-283 and everything from
 * 285 on are hidden. This includes the tail of the eigen_pastix argument list
 * and the error-handling branch.
 */
261template <
class Derived>
262void PastixBase<Derived>::init() {
264 m_iparm.setZero(IPARM_SIZE);
265 m_dparm.setZero(DPARM_SIZE);
267 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
268 pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, 0, 0, 0, 0, 1, m_iparm.data(), m_dparm.data());
// Explicit parameter settings applied after the call above.
270 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
271 m_iparm[IPARM_VERBOSE] = API_VERBOSE_NOT;
272 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
273 m_iparm[IPARM_INCOMPLETE] = API_NO;
274 m_iparm[IPARM_OOC_LIMIT] = 2000;
275 m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
// NOTE(review): repeats the IPARM_MATRIX_VERIFICATION = API_NO assignment from
// embedded line 270, so one of the two is redundant.
276 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
// Run only the PaStiX initialization task (no matrix arrays, nbrhs == 0).
278 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
279 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
280 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar *)0, 0, 0, 0, 0, m_iparm.data(),
// A non-zero IPARM_ERROR_NUMBER signals failure (handling is on hidden lines).
284 if (m_iparm(IPARM_ERROR_NUMBER)) {
/** \internal
 * Whole-pipeline entry point on a column-major matrix (ColSpMatrix).
 *
 * Visible steps: assert that mat is square, then set IPARM_MATRIX_VERIFICATION
 * to API_NO.
 * NOTE(review): embedded lines 296-299 (presumably the analysis and
 * factorization calls) and the closing brace are hidden here.
 */
293template <
class Derived>
294void PastixBase<Derived>::compute(ColSpMatrix &mat) {
295 eigen_assert(mat.rows() == mat.cols() &&
"The input matrix should be squared");
300 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
/** \internal
 * Symbolic phase: runs PaStiX tasks API_TASK_ORDERING through API_TASK_ANALYSE
 * on the pattern of mat.
 *
 * Requires a successful initialization (m_initisOk). If a previous problem was
 * set up (m_size > 0), it first calls clean(). It then sizes m_perm / m_invp to
 * mat.rows() and hands their buffers to PaStiX (presumably to receive the
 * computed permutation). The outcome is recorded in m_analysisIsOk according to
 * IPARM_ERROR_NUMBER.
 *
 * \param mat column-major matrix; the front-ends pass it after grabMatrix has
 *            converted it to Fortran (1-based) numbering.
 * NOTE(review): embedded lines 321/323-324 and the closing braces are hidden.
 */
303template <
class Derived>
304void PastixBase<Derived>::analyzePattern(ColSpMatrix &mat) {
305 eigen_assert(m_initisOk &&
"The initialization of PaSTiX failed");
// Release the data of a previous analysis before starting a new one.
308 if (m_size > 0) clean();
310 m_size = internal::convert_index<int>(mat.rows());
311 m_perm.resize(m_size);
312 m_invp.resize(m_size);
314 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
315 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
316 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
317 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// Record success or failure of the analysis.
320 if (m_iparm(IPARM_ERROR_NUMBER)) {
322 m_analysisIsOk =
false;
325 m_analysisIsOk =
true;
/** \internal
 * Numeric phase: runs PaStiX task API_TASK_NUMFACT on mat, reusing the
 * analysis, which is asserted via m_analysisIsOk.
 *
 * On a non-zero IPARM_ERROR_NUMBER, m_factorizationIsOk and m_isInitialized
 * both become false; otherwise both become true. _solve_impl asserts
 * m_isInitialized.
 *
 * NOTE(review): m_size is re-read from mat.rows() without being checked
 * against the size used in analyzePattern, which is the size m_perm / m_invp
 * were given. A matrix of a different size would reach PaStiX together with
 * permutation arrays of the old length. Consider
 * eigen_assert(mat.rows() == m_size) instead of reassigning.
 */
329template <
class Derived>
330void PastixBase<Derived>::factorize(ColSpMatrix &mat) {
332 eigen_assert(m_analysisIsOk &&
"The analysis phase should be called before the factorization phase");
333 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
334 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
335 m_size = internal::convert_index<int>(mat.rows());
337 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
338 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// Record success or failure of the factorization.
341 if (m_iparm(IPARM_ERROR_NUMBER)) {
343 m_factorizationIsOk =
false;
344 m_isInitialized =
false;
347 m_factorizationIsOk =
true;
348 m_isInitialized =
true;
/** \internal
 * Solves with the factorization held by PaStiX, one right-hand-side column at
 * a time.
 *
 * Requires a prior successful factorize(), asserted via m_isInitialized. Also
 * requires a column-major destination (static assert), which is what lets
 * &x(0, i) address column i as a contiguous buffer. For each column, PaStiX
 * runs tasks API_TASK_SOLVE through API_TASK_REFINE directly on that column
 * of x.
 *
 * \return true when the IPARM_ERROR_NUMBER left by PaStiX is 0.
 * NOTE(review): embedded lines 358-361, 365 and 368-372 are hidden. These
 * include the declaration of `rhs` and whatever fills x from b; as visible, x
 * is both input and output of each PaStiX call.
 * NOTE(review): as visible, only the error number left by the last column's
 * call is returned, so a failure on an earlier column would be masked unless
 * the hidden lines check it.
 * NOTE(review): the class template parameter is named Base here but Derived in
 * the class definition. This is legal but inconsistent.
 */
353template <
typename Base>
354template <
typename Rhs,
typename Dest>
355bool PastixBase<Base>::_solve_impl(
const MatrixBase<Rhs> &b, MatrixBase<Dest> &x)
const {
356 eigen_assert(m_isInitialized &&
"The matrix should be factorized first");
357 EIGEN_STATIC_ASSERT((Dest::Flags & RowMajorBit) == 0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
362 for (
int i = 0; i < b.cols(); i++) {
// m_iparm is mutable, which is why this const method may set the task range.
363 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
364 m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
366 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, internal::convert_index<int>(x.rows()), 0, 0, 0,
367 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
373 return m_iparm(IPARM_ERROR_NUMBER) == 0;
/** \brief PaStiX LU front-end: sets IPARM_SYM = API_SYM_NO and
 * IPARM_FACTORIZATION = API_FACT_LU.
 *
 * grabMatrix builds the matrix handed to PaStiX as the sum of two parts:
 *  - the transposed pattern of the input with every value zeroed;
 *  - the input itself.
 * The result carries the input's values on the union of its own pattern and
 * its transpose's pattern. The zeroed transposed structure is cached in
 * m_transposedStructure and rebuilt only while m_structureIsUptodate is false.
 *
 * NOTE(review): the base class is PastixBase<PastixLU<MatrixType_> >, which
 * picks up the default IsStrSym (false) from the forward declaration. For
 * PastixLU<M, true>, the CRTP Derived type therefore differs from the class
 * itself. Confirm whether IsStrSym = true is supported. If it is, the base
 * should likely be PastixBase<PastixLU<MatrixType_, IsStrSym> >, with a
 * matching pastix_traits specialization.
 * NOTE(review): many lines of this class are hidden, including:
 *  - access specifiers;
 *  - method headers;
 *  - the declaration of `temp`;
 *  - any use of IsStrSym in grabMatrix.
 */
397template <
typename MatrixType_,
bool IsStrSym>
398class PastixLU :
public PastixBase<PastixLU<MatrixType_> > {
400 typedef MatrixType_ MatrixType;
401 typedef PastixBase<PastixLU<MatrixType> > Base;
403 typedef typename MatrixType::StorageIndex StorageIndex;
408 explicit PastixLU(
const MatrixType &matrix) : Base() {
// compute() body fragment (header at embedded line 417): forces the cached
// structure to be rebuilt, then converts the input.
418 m_structureIsUptodate =
false;
420 grabMatrix(matrix, temp);
// analyzePattern() body fragment (header at embedded line 428): rebuilds the
// structure and runs the symbolic phase.
429 m_structureIsUptodate =
false;
431 grabMatrix(matrix, temp);
432 Base::analyzePattern(temp);
// factorize() body fragment (header at embedded line 440). As visible, it does
// not invalidate the cached structure, so it assumes the sparsity pattern has
// not changed since analyzePattern.
442 grabMatrix(matrix, temp);
443 Base::factorize(temp);
// Presumably the front-end's init() (header hidden): select the unsymmetric LU
// factorization.
448 m_structureIsUptodate =
false;
449 m_iparm(IPARM_SYM) = API_SYM_NO;
450 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
// Builds the PaStiX input from the user's matrix, in Fortran numbering.
453 void grabMatrix(
const MatrixType &matrix, ColSpMatrix &out) {
457 if (!m_structureIsUptodate) {
459 m_transposedStructure = matrix.transpose();
// Keep only the pattern: zero every stored value of the transpose.
462 for (Index j = 0; j < m_transposedStructure.
outerSize(); ++j)
463 for (
typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it) it.valueRef() = 0.0;
465 m_structureIsUptodate =
true;
// Input values on the union pattern of A and A^T.
468 out = m_transposedStructure + matrix;
470 internal::c_to_fortran_numbering(out);
476 ColSpMatrix m_transposedStructure;
477 bool m_structureIsUptodate;
/** \brief PaStiX Cholesky (LLT) front-end: sets IPARM_SYM = API_SYM_YES and
 * IPARM_FACTORIZATION = API_FACT_LLT.
 *
 * UpLo selects which triangle of the input is read. grabMatrix expands that
 * triangle, as a self-adjoint view, into the lower triangle of a column-major
 * copy, then switches the copy to Fortran (1-based) numbering.
 *
 * NOTE(review): lines between the embedded numbers are hidden, including
 * method headers, the declaration of `temp` and access specifiers.
 */
496template <
typename MatrixType_,
int UpLo_>
497class PastixLLT :
public PastixBase<PastixLLT<MatrixType_, UpLo_> > {
499 typedef MatrixType_ MatrixType;
500 typedef PastixBase<PastixLLT<MatrixType, UpLo_> > Base;
504 enum { UpLo = UpLo_ };
507 explicit PastixLLT(
const MatrixType &matrix) : Base() {
// compute() body fragment (header at embedded line 515).
517 grabMatrix(matrix, temp);
// analyzePattern() body fragment (header at embedded line 525).
527 grabMatrix(matrix, temp);
528 Base::analyzePattern(temp);
// factorize() body fragment (header at embedded line 533).
535 grabMatrix(matrix, temp);
536 Base::factorize(temp);
// Presumably the front-end's init() (header hidden): symmetric LLT factorization.
543 m_iparm(IPARM_SYM) = API_SYM_YES;
544 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
// Builds the PaStiX input: the UpLo triangle of the input, stored as its lower
// triangle, in Fortran numbering.
547 void grabMatrix(
const MatrixType &matrix, ColSpMatrix &out) {
548 out.
resize(matrix.rows(), matrix.cols());
550 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
551 internal::c_to_fortran_numbering(out);
/** \brief PaStiX LDLT front-end: sets IPARM_SYM = API_SYM_YES and
 * IPARM_FACTORIZATION = API_FACT_LDLT.
 *
 * UpLo selects which triangle of the input is read. grabMatrix expands that
 * triangle, as a self-adjoint view, into the lower triangle of a column-major
 * copy, then switches the copy to Fortran (1-based) numbering.
 *
 * NOTE(review): lines between the embedded numbers are hidden, including
 * method headers, the declaration of `temp` and access specifiers.
 */
571template <
typename MatrixType_,
int UpLo_>
572class PastixLDLT :
public PastixBase<PastixLDLT<MatrixType_, UpLo_> > {
574 typedef MatrixType_ MatrixType;
575 typedef PastixBase<PastixLDLT<MatrixType, UpLo_> > Base;
579 enum { UpLo = UpLo_ };
582 explicit PastixLDLT(
const MatrixType &matrix) : Base() {
// compute() body fragment (header at embedded line 590).
592 grabMatrix(matrix, temp);
// analyzePattern() body fragment (header at embedded line 600).
602 grabMatrix(matrix, temp);
603 Base::analyzePattern(temp);
// factorize() body fragment (header at embedded line 608).
610 grabMatrix(matrix, temp);
611 Base::factorize(temp);
// Presumably the front-end's init() (header hidden): symmetric LDLT factorization.
618 m_iparm(IPARM_SYM) = API_SYM_YES;
619 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
// Builds the PaStiX input: the UpLo triangle of the input, stored as its lower
// triangle, in Fortran numbering.
622 void grabMatrix(
const MatrixType &matrix, ColSpMatrix &out) {
624 out.
resize(matrix.rows(), matrix.cols());
625 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
626 internal::c_to_fortran_numbering(out);
A sparse direct supernodal Cholesky (LDLT) factorization and solver based on the PaStiX library.
Definition PaStiXSupport.h:572
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:590
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:600
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:608
A sparse direct supernodal Cholesky (LLT) factorization and solver based on the PaStiX library.
Definition PaStiXSupport.h:497
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:515
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:533
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:525
Interface to the PaStiX solver.
Definition PaStiXSupport.h:398
void analyzePattern(const MatrixType &matrix)
Definition PaStiXSupport.h:428
void factorize(const MatrixType &matrix)
Definition PaStiXSupport.h:440
void compute(const MatrixType &matrix)
Definition PaStiXSupport.h:417
constexpr void resize(Index rows, Index cols)
Definition PlainObjectBase.h:294
const Scalar * data() const
Definition PlainObjectBase.h:273
A versatile sparse matrix representation.
Definition SparseUtil.h:47
Index outerSize() const
Definition SparseMatrix.h:166
SparseSolverBase()
Definition SparseSolverBase.h:70
@ NumericalIssue
Definition Constants.h:442
@ InvalidInput
Definition Constants.h:447
@ Success
Definition Constants.h:440
Namespace containing all symbols from the Eigen library.
Definition Core:137