37 #ifndef HSS_MATRIX_HPP 
   38 #define HSS_MATRIX_HPP 
   44 #include "HSSBasisID.hpp" 
   46 #include "HSSExtra.hpp" 
   49 #include "HSSMatrix.sketch.hpp" 
   54 #ifndef DOXYGEN_SHOULD_SKIP_THIS 
   56     template<
typename scalar_t> 
class HSSMatrixMPI;
 
   81       using real_t = 
typename RealType<scalar_t>::value_type;
 
   84       using elem_t = 
typename std::function
 
   85         <void(
const std::vector<std::size_t>& I,
 
   86               const std::vector<std::size_t>& J, 
DenseM_t& B)>;
 
   87       using mult_t = 
typename std::function
 
  186       std::unique_ptr<HSSMatrixBase<scalar_t>> 
clone() 
const override;
 
  274                     const std::function<
void(
const std::vector<std::size_t>& I,
 
  275                                              const std::vector<std::size_t>& J,
 
  308                                      <
void(
const std::vector<std::size_t>& I,
 
  309                                            const std::vector<std::size_t>& J,
 
  365                          bool partial) 
const override;
 
  427       scalar_t 
get(std::size_t i, std::size_t j) 
const;
 
  440                        const std::vector<std::size_t>& J) 
const;
 
  456                        const std::vector<std::size_t>& J,
 
  459 #ifndef DOXYGEN_SHOULD_SKIP_THIS 
  463       void Schur_product_direct(
const DenseM_t& Theta,
 
  466                                 const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
 
  469       void Schur_product_indirect(
const DenseM_t& DUB01,
 
  474       void delete_trailing_block() 
override;
 
  477       std::size_t 
rank() 
const override;
 
  484                       std::size_t coff=0) 
const override;
 
  494       void shift(scalar_t sigma) 
override;
 
  497                 std::size_t rlo=0, std::size_t clo=0) 
const override;
 
  505       void write(
const std::string& fname) 
const;
 
  519                 const opts_t& opts, 
bool active);
 
  521                 const opts_t& opts, 
bool active);
 
  524       HSSBasisID<scalar_t> U_, V_;
 
  525       DenseM_t D_, B01_, B10_;
 
  527       void compress_original(
const DenseM_t& A,
 
  529       void compress_original(
const mult_t& Amult,
 
  532       void compress_stable(
const DenseM_t& A,
 
  534       void compress_stable(
const mult_t& Amult,
 
  537       void compress_hard_restart(
const DenseM_t& A,
 
  539       void compress_hard_restart(
const mult_t& Amult,
 
  543       void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
 
  544                                        DenseM_t& Sr, DenseM_t& Sc,
 
  547                                        WorkCompress<scalar_t>& w,
 
  548                                        int dd, 
int depth) 
override;
 
  549       void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
 
  550                                      DenseM_t& Sr, DenseM_t& Sc,
 
  553                                      WorkCompress<scalar_t>& w,
 
  554                                      int d, 
int dd, 
int depth) 
override;
 
  556       void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
 
  557                                  DenseM_t& Sr, DenseM_t& Sc,
 
  558                                  WorkCompress<scalar_t>& w,
 
  559                                  int d0, 
int d, 
int depth,
 
  560                                  SJLTMatrix<scalar_t, int>* S=
nullptr);
 
  561       bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc, 
const opts_t& opts,
 
  562                              WorkCompress<scalar_t>& w, 
int d, 
int depth);
 
  563       void compute_U_basis_stable(DenseM_t& Sr, 
const opts_t& opts,
 
  564                                   WorkCompress<scalar_t>& w,
 
  565                                   int d, 
int dd, 
int depth);
 
  566       void compute_V_basis_stable(DenseM_t& Sc, 
const opts_t& opts,
 
  567                                   WorkCompress<scalar_t>& w,
 
  568                                   int d, 
int dd, 
int depth);
 
  569       void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
 
  570                                 WorkCompress<scalar_t>& w,
 
  571                                 int d0, 
int d, 
int depth);
 
  572       bool update_orthogonal_basis(
const opts_t& opts, scalar_t& r_max_0,
 
  573                                    const DenseM_t& S, DenseM_t& Q,
 
  574                                    int d, 
int dd, 
bool untouched,
 
  576       void set_U_full_rank(WorkCompress<scalar_t>& w);
 
  577       void set_V_full_rank(WorkCompress<scalar_t>& w);
 
  579       void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
 
  580                                    DenseM_t& Sr, DenseM_t& Sc,
 
  582                                    WorkCompress<scalar_t>& w,
 
  583                                    int dd, 
int lvl, 
int depth) 
override;
 
  584       void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
 
  585                                  DenseM_t& Sr, DenseM_t& Sc,
 
  587                                  WorkCompress<scalar_t>& w,
 
  588                                  int d, 
int dd, 
int lvl, 
int depth) 
override;
 
  589       void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
 
  590                                   std::vector<std::vector<std::size_t>>& J,
 
  591                                   const std::pair<std::size_t,std::size_t>& off,
 
  592                                   WorkCompress<scalar_t>& w,
 
  593                                   int& 
self, 
int lvl) 
override;
 
  594       void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
 
  595                                   std::vector<std::vector<std::size_t>>& J,
 
  596                                   std::vector<DenseM_t*>& B,
 
  597                                   const std::pair<std::size_t,std::size_t>& off,
 
  598                                   WorkCompress<scalar_t>& w,
 
  599                                   int& 
self, 
int lvl) 
override;
 
  600       void extract_D_B(
const elem_t& Aelem, 
const opts_t& opts,
 
  601                        WorkCompress<scalar_t>& w, 
int lvl) 
override;
 
  606                                   const elem_t& Aelem, 
const opts_t& opts,
 
  607                                   WorkCompressANN<scalar_t>& w,
 
  611                                      WorkCompressANN<scalar_t>& w,
 
  612                                      const elem_t& Aelem, 
const opts_t& opts);
 
  613       bool compute_U_V_bases_ann(DenseM_t& S, 
const opts_t& opts,
 
  614                                  WorkCompressANN<scalar_t>& w, 
int depth);
 
  616       void factor_recursive(WorkFactor<scalar_t>& w,
 
  617                             bool isroot, 
bool partial,
 
  620       void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
 
  621                      bool isroot, 
int depth,
 
  622                      std::atomic<long long int>& flops) 
const override;
 
  623       void apply_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
 
  624                      WorkApply<scalar_t>& w, 
bool isroot, 
int depth,
 
  625                      std::atomic<long long int>& flops) 
const override;
 
  626       void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w, 
bool isroot,
 
  627                       int depth, std::atomic<long long int>& flops) 
const override;
 
  628       void applyT_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
 
  629                       WorkApply<scalar_t>& w, 
bool isroot, 
int depth,
 
  630                       std::atomic<long long int>& flops) 
const override;
 
  632       void solve_fwd(
const DenseM_t& b, WorkSolve<scalar_t>& w,
 
  633                      bool partial, 
bool isroot, 
int depth) 
const override;
 
  634       void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
 
  635                      bool isroot, 
int depth) 
const override;
 
  637       void extract_fwd(WorkExtract<scalar_t>& w,
 
  638                        bool odiag, 
int depth) 
const override;
 
  639       void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
 
  640                        int depth) 
const override;
 
  641       void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
 
  642                        WorkExtract<scalar_t>& w, 
int depth) 
const override;
 
  643       void extract_bwd_internal(WorkExtract<scalar_t>& w, 
int depth) 
const;
 
  645       void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
 
  647                         const std::pair<std::size_t, std::size_t>& offset,
 
  648                         int depth, std::atomic<long long int>& flops)
 
  650       void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
 
  651                           const std::pair<std::size_t, std::size_t>& offset,
 
  652                           int depth, std::atomic<long long int>& flops)
 
  655       void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
 
  656                            bool isroot, 
int depth) 
const override;
 
  661       template<
typename T> 
friend 
  668       template<
typename T> 
friend 
  671       void read(std::ifstream& is) 
override;
 
  672       void write(std::ofstream& os) 
const override;
 
  686     template<
typename scalar_t>
 
  700     template<
typename scalar_t> 
void 
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representation...
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kernel...
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines...
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:80
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:194
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
scalar_t get(std::size_t i, std::size_t j) const
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
void write(const std::string &fname) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:204
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hierarchical...
Definition: ClusterTree.hpp:67
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51