37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
44 #include "HSSBasisID.hpp"
46 #include "HSSExtra.hpp"
53 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/// Forward declaration of the HSSMatrixMPI class template. Only the name is
/// introduced here; the class itself is defined elsewhere.
55 template<
typename scalar_t>
class HSSMatrixMPI;
/// Real-valued type associated with scalar_t, taken from
/// RealType<scalar_t>::value_type.
80 using real_t =
typename RealType<scalar_t>::value_type;
/// Type-erased callback that receives a row index list I, a column index
/// list J and a dense matrix B by mutable reference, and returns nothing.
/// NOTE(review): presumably the callee writes the (I, J) entries of the
/// matrix being compressed into B -- confirm against the compress() docs.
83 using elem_t =
typename std::function
84 <void(
const std::vector<std::size_t>& I,
85 const std::vector<std::size_t>& J,
DenseM_t& B)>;
86 using mult_t =
typename std::function
/// Returns an owning std::unique_ptr to an HSSMatrixBase<scalar_t>.
/// Const member that overrides a base-class virtual.
/// NOTE(review): presumably the returned object is a copy of *this -- confirm.
185 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
273 const std::function<
void(
const std::vector<std::size_t>& I,
274 const std::vector<std::size_t>& J,
307 <
void(
const std::vector<std::size_t>& I,
308 const std::vector<std::size_t>& J,
364 bool partial)
const override;
/// Returns one scalar_t value for the index pair (i, j) without modifying
/// the matrix (const member).
/// NOTE(review): presumably i is the row and j the column of the requested
/// entry -- confirm.
426 scalar_t
get(std::size_t i, std::size_t j)
const;
439 const std::vector<std::size_t>& J)
const;
455 const std::vector<std::size_t>& J,
458 #ifndef DOXYGEN_SHOULD_SKIP_THIS
462 void Schur_product_direct(
const DenseM_t& Theta,
465 const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
468 void Schur_product_indirect(
const DenseM_t& DUB01,
/// Mutating member with no arguments and no return value; overrides a
/// base-class virtual.
/// NOTE(review): which block counts as "trailing" is not visible from this
/// declaration -- confirm against the base-class documentation.
473 void delete_trailing_block()
override;
/// Returns a std::size_t rank value; const member overriding a base-class
/// virtual.
/// NOTE(review): presumably the maximum rank over the HSS hierarchy -- confirm.
476 std::size_t
rank()
const override;
483 std::size_t coff=0)
const override;
/// Mutating member taking a scalar sigma; overrides a base-class virtual.
/// NOTE(review): presumably applies a diagonal shift by sigma -- confirm.
493 void shift(scalar_t sigma)
override;
496 std::size_t rlo=0, std::size_t clo=0)
const override;
/// Const member taking a file name; returns nothing.
/// NOTE(review): presumably serializes this matrix to the file fname, as the
/// counterpart of the static read(const std::string&) listed for this
/// class -- confirm.
504 void write(
const std::string& fname)
const;
518 const opts_t& opts,
bool active);
520 const opts_t& opts,
bool active);
/// U_ and V_ are HSSBasisID<scalar_t> objects.
/// NOTE(review): presumably the row (U) and column (V) bases of this node -- confirm.
523 HSSBasisID<scalar_t> U_, V_;
/// D_, B01_ and B10_ are dense matrices.
/// NOTE(review): presumably the diagonal block and the two off-diagonal
/// coupling blocks between child 0 and child 1 -- confirm.
524 DenseM_t D_, B01_, B10_;
526 void compress_original(
const DenseM_t& A,
528 void compress_original(
const mult_t& Amult,
531 void compress_stable(
const DenseM_t& A,
533 void compress_stable(
const mult_t& Amult,
536 void compress_hard_restart(
const DenseM_t& A,
538 void compress_hard_restart(
const mult_t& Amult,
542 void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
543 DenseM_t& Sr, DenseM_t& Sc,
546 WorkCompress<scalar_t>& w,
547 int dd,
int depth)
override;
548 void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
549 DenseM_t& Sr, DenseM_t& Sc,
552 WorkCompress<scalar_t>& w,
553 int d,
int dd,
int depth)
override;
/// Non-virtual helper (no override specifier). All four dense matrices are
/// taken by mutable reference, together with the compression workspace w.
/// NOTE(review): presumably Rr/Rc are random row/column sample inputs and
/// Sr/Sc the corresponding sample results, with d0 and d delimiting the
/// sample columns to process -- confirm.
554 void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
555 DenseM_t& Sr, DenseM_t& Sc,
556 WorkCompress<scalar_t>& w,
557 int d0,
int d,
int depth);
/// Non-virtual helper returning a bool. Sr and Sc are taken by mutable
/// reference, opts is read-only.
/// NOTE(review): the meaning of the returned flag (presumably "bases accepted
/// at the requested tolerance") is not visible here -- confirm.
558 bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc,
const opts_t& opts,
559 WorkCompress<scalar_t>& w,
int d,
int depth);
/// Non-virtual helper of the "stable" compression variant. Takes the sample
/// matrix Sr by mutable reference.
/// NOTE(review): presumably d is the number of samples already processed and
/// dd the number newly added -- confirm.
560 void compute_U_basis_stable(DenseM_t& Sr,
const opts_t& opts,
561 WorkCompress<scalar_t>& w,
562 int d,
int dd,
int depth);
/// Non-virtual helper of the "stable" compression variant. Its signature
/// mirrors compute_U_basis_stable, with Sc in place of Sr.
/// NOTE(review): presumably d is the number of samples already processed and
/// dd the number newly added -- confirm.
563 void compute_V_basis_stable(DenseM_t& Sc,
const opts_t& opts,
564 WorkCompress<scalar_t>& w,
565 int d,
int dd,
int depth);
/// Non-virtual helper taking Rr and Rc by mutable reference, plus the
/// compression workspace w.
/// NOTE(review): presumably d0 and d delimit the sample columns to
/// reduce -- confirm.
566 void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
567 WorkCompress<scalar_t>& w,
568 int d0,
int d,
int depth);
569 bool update_orthogonal_basis(
const opts_t& opts, scalar_t& r_max_0,
570 const DenseM_t& S, DenseM_t& Q,
571 int d,
int dd,
bool untouched,
/// Non-virtual helpers that take only the compression workspace.
/// NOTE(review): presumably they mark the U (resp. V) basis as full rank,
/// i.e. no compression at this node -- confirm.
573 void set_U_full_rank(WorkCompress<scalar_t>& w);
574 void set_V_full_rank(WorkCompress<scalar_t>& w);
576 void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
577 DenseM_t& Sr, DenseM_t& Sc,
579 WorkCompress<scalar_t>& w,
580 int dd,
int lvl,
int depth)
override;
581 void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
582 DenseM_t& Sr, DenseM_t& Sc,
584 WorkCompress<scalar_t>& w,
585 int d,
int dd,
int lvl,
int depth)
override;
/// Overrides a base-class virtual. I and J are lists of index lists taken by
/// mutable reference, off is a read-only (size_t, size_t) pair, and self is an
/// int taken by mutable reference.
/// NOTE(review): presumably appends the row/column index sets needed at level
/// lvl to I and J, with off a (row, column) offset and self a running node
/// counter -- confirm.
586 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
587 std::vector<std::vector<std::size_t>>& J,
588 const std::pair<std::size_t,std::size_t>& off,
589 WorkCompress<scalar_t>& w,
590 int&
self,
int lvl)
override;
/// Overload of get_extraction_indices that additionally takes a list of raw,
/// non-owning DenseM_t pointers B; overrides a base-class virtual.
/// NOTE(review): presumably B collects the destination blocks matching each
/// (I, J) pair -- confirm.
591 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
592 std::vector<std::vector<std::size_t>>& J,
593 std::vector<DenseM_t*>& B,
594 const std::pair<std::size_t,std::size_t>& off,
595 WorkCompress<scalar_t>& w,
596 int&
self,
int lvl)
override;
/// Overrides a base-class virtual. Takes the element-extraction callback
/// Aelem (see elem_t), the options, the compression workspace and a level.
/// NOTE(review): presumably fills the D and B blocks of nodes at level lvl
/// through Aelem -- confirm.
597 void extract_D_B(
const elem_t& Aelem,
const opts_t& opts,
598 WorkCompress<scalar_t>& w,
int lvl)
override;
603 const elem_t& Aelem,
const opts_t& opts,
604 WorkCompressANN<scalar_t>& w,
608 WorkCompressANN<scalar_t>& w,
609 const elem_t& Aelem,
const opts_t& opts);
/// Non-virtual helper returning a bool. Takes a single sample matrix S by
/// mutable reference and the WorkCompressANN workspace.
/// NOTE(review): the meaning of the returned flag is not visible here -- confirm.
610 bool compute_U_V_bases_ann(DenseM_t& S,
const opts_t& opts,
611 WorkCompressANN<scalar_t>& w,
int depth);
613 void factor_recursive(WorkFactor<scalar_t>& w,
614 bool isroot,
bool partial,
/// Const member overriding a base-class virtual. Reads b and uses the apply
/// workspace w; flops is an atomic counter taken by mutable reference.
/// NOTE(review): presumably the forward sweep of the matrix-vector
/// product, adding its operation count to flops -- confirm.
617 void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
618 bool isroot,
int depth,
619 std::atomic<long long int>& flops)
const override;
/// Const member overriding a base-class virtual. Reads b, takes a scalar beta
/// and writes through c.
/// NOTE(review): presumably the backward sweep of the product,
/// combining the result with beta * c -- confirm.
620 void apply_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
621 WorkApply<scalar_t>& w,
bool isroot,
int depth,
622 std::atomic<long long int>& flops)
const override;
/// Const member overriding a base-class virtual, with the same parameter
/// list as apply_fwd.
/// NOTE(review): presumably the forward sweep for the transposed
/// (or conjugate-transposed) product -- confirm.
623 void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
624 int depth, std::atomic<long long int>& flops)
const override;
/// Const member overriding a base-class virtual, with the same parameter
/// list as apply_bwd.
/// NOTE(review): presumably the backward sweep for the transposed
/// (or conjugate-transposed) product -- confirm.
625 void applyT_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
626 WorkApply<scalar_t>& w,
bool isroot,
int depth,
627 std::atomic<long long int>& flops)
const override;
/// Const member overriding a base-class virtual. Reads the right-hand side b
/// and uses the solve workspace w.
/// NOTE(review): presumably the forward phase of the solve, with
/// partial selecting a partial solve -- confirm.
629 void solve_fwd(
const DenseM_t& b, WorkSolve<scalar_t>& w,
630 bool partial,
bool isroot,
int depth)
const override;
/// Const member overriding a base-class virtual. Takes x by mutable reference
/// together with the solve workspace w.
/// NOTE(review): presumably the backward phase of the solve, writing
/// the solution into x -- confirm.
631 void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
632 bool isroot,
int depth)
const override;
/// Const member overriding a base-class virtual. Operates on the extraction
/// workspace w.
/// NOTE(review): presumably the forward sweep of sub-block extraction, with
/// odiag marking an off-diagonal position -- confirm.
634 void extract_fwd(WorkExtract<scalar_t>& w,
635 bool odiag,
int depth)
const override;
/// Const member overriding a base-class virtual. The dense matrix B is taken
/// by mutable reference.
/// NOTE(review): presumably the backward sweep of sub-block extraction,
/// writing the extracted entries into B -- confirm.
636 void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
637 int depth)
const override;
/// Overload of extract_bwd that takes a vector of Triplet<scalar_t> by
/// mutable reference instead of a dense matrix; const, overrides a base-class
/// virtual.
/// NOTE(review): presumably appends the extracted entries as triplets -- confirm.
638 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
639 WorkExtract<scalar_t>& w,
int depth)
const override;
/// Const, non-virtual helper (no override specifier) operating on the
/// extraction workspace.
/// NOTE(review): presumably the part shared by both extract_bwd
/// overloads -- confirm.
640 void extract_bwd_internal(WorkExtract<scalar_t>& w,
int depth)
const;
642 void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
644 const std::pair<std::size_t, std::size_t>& offset,
645 int depth, std::atomic<long long int>& flops)
647 void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
648 const std::pair<std::size_t, std::size_t>& offset,
649 int depth, std::atomic<long long int>& flops)
/// Const member overriding a base-class virtual. A is taken by mutable
/// reference together with the WorkDense workspace.
/// NOTE(review): presumably expands this HSS node into the dense
/// matrix A -- confirm.
652 void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
653 bool isroot,
int depth)
const override;
658 template<
typename T>
friend
665 template<
typename T>
friend
/// Mutating member taking an open input file stream; overrides a base-class
/// virtual.
/// NOTE(review): presumably restores this node from the stream in the format
/// produced by write(std::ofstream&) -- confirm.
668 void read(std::ifstream& is)
override;
/// Const member taking an open output file stream; overrides a base-class
/// virtual.
/// NOTE(review): presumably serializes this node to the stream in the format
/// expected by read(std::ifstream&) -- confirm.
669 void write(std::ofstream& os)
const override;
683 template<
typename scalar_t>
697 template<
typename scalar_t>
void
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representations.
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kernel ridge regression.
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:79
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:193
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
scalar_t get(std::size_t i, std::size_t j) const
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
void write(const std::string &fname) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:203
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:118
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hierarchical matrix.
Definition: ClusterTree.hpp:62
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:42
Trans
Definition: DenseMatrix.hpp:51