44#include "HSSBasisID.hpp"
46#include "HSSExtra.hpp"
49#include "HSSMatrix.sketch.hpp"
54#ifndef DOXYGEN_SHOULD_SKIP_THIS
56 template<
typename scalar_t>
class HSSMatrixMPI;
81 using real_t =
typename RealType<scalar_t>::value_type;
84 using elem_t =
typename std::function
85 <void(
const std::vector<std::size_t>& I,
86 const std::vector<std::size_t>& J,
DenseM_t& B)>;
87 using mult_t =
typename std::function
186 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
274 const std::function<
void(
const std::vector<std::size_t>& I,
275 const std::vector<std::size_t>& J,
308 <
void(
const std::vector<std::size_t>& I,
309 const std::vector<std::size_t>& J,
365 bool partial)
const override;
427 scalar_t
get(std::size_t i, std::size_t j)
const;
440 const std::vector<std::size_t>& J)
const;
456 const std::vector<std::size_t>& J,
459#ifndef DOXYGEN_SHOULD_SKIP_THIS
463 void Schur_product_direct(
const DenseM_t& Theta,
466 const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
469 void Schur_product_indirect(
const DenseM_t& DUB01,
474 void delete_trailing_block()
override;
477 std::size_t
rank()
const override;
484 std::size_t coff=0)
const override;
494 void shift(scalar_t sigma)
override;
497 std::size_t rlo=0, std::size_t clo=0)
const override;
505 void write(
const std::string& fname)
const;
519 const opts_t& opts,
bool active);
521 const opts_t& opts,
bool active);
524 HSSBasisID<scalar_t> U_, V_;
525 DenseM_t D_, B01_, B10_;
527 void compress_original(
const DenseM_t& A,
529 void compress_original(
const mult_t& Amult,
532 void compress_stable(
const DenseM_t& A,
534 void compress_stable(
const mult_t& Amult,
537 void compress_hard_restart(
const DenseM_t& A,
539 void compress_hard_restart(
const mult_t& Amult,
543 void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
544 DenseM_t& Sr, DenseM_t& Sc,
547 WorkCompress<scalar_t>& w,
548 int dd,
int depth)
override;
549 void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
550 DenseM_t& Sr, DenseM_t& Sc,
553 WorkCompress<scalar_t>& w,
554 int d,
int dd,
int depth)
override;
556 void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
557 DenseM_t& Sr, DenseM_t& Sc,
558 WorkCompress<scalar_t>& w,
559 int d0,
int d,
int depth,
560 SJLTMatrix<scalar_t, int>* S=
nullptr);
561 bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc,
const opts_t& opts,
562 WorkCompress<scalar_t>& w,
int d,
int depth);
563 void compute_U_basis_stable(DenseM_t& Sr,
const opts_t& opts,
564 WorkCompress<scalar_t>& w,
565 int d,
int dd,
int depth);
566 void compute_V_basis_stable(DenseM_t& Sc,
const opts_t& opts,
567 WorkCompress<scalar_t>& w,
568 int d,
int dd,
int depth);
569 void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
570 WorkCompress<scalar_t>& w,
571 int d0,
int d,
int depth);
572 bool update_orthogonal_basis(
const opts_t& opts, scalar_t& r_max_0,
573 const DenseM_t& S, DenseM_t& Q,
574 int d,
int dd,
bool untouched,
576 void set_U_full_rank(WorkCompress<scalar_t>& w);
577 void set_V_full_rank(WorkCompress<scalar_t>& w);
579 void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
580 DenseM_t& Sr, DenseM_t& Sc,
582 WorkCompress<scalar_t>& w,
583 int dd,
int lvl,
int depth)
override;
584 void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
585 DenseM_t& Sr, DenseM_t& Sc,
587 WorkCompress<scalar_t>& w,
588 int d,
int dd,
int lvl,
int depth)
override;
589 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
590 std::vector<std::vector<std::size_t>>& J,
591 const std::pair<std::size_t,std::size_t>& off,
592 WorkCompress<scalar_t>& w,
593 int& self,
int lvl)
override;
594 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
595 std::vector<std::vector<std::size_t>>& J,
596 std::vector<DenseM_t*>& B,
597 const std::pair<std::size_t,std::size_t>& off,
598 WorkCompress<scalar_t>& w,
599 int& self,
int lvl)
override;
600 void extract_D_B(
const elem_t& Aelem,
const opts_t& opts,
601 WorkCompress<scalar_t>& w,
int lvl)
override;
606 const elem_t& Aelem,
const opts_t& opts,
607 WorkCompressANN<scalar_t>& w,
611 WorkCompressANN<scalar_t>& w,
612 const elem_t& Aelem,
const opts_t& opts);
613 bool compute_U_V_bases_ann(DenseM_t& S,
const opts_t& opts,
614 WorkCompressANN<scalar_t>& w,
int depth);
616 void factor_recursive(WorkFactor<scalar_t>& w,
617 bool isroot,
bool partial,
620 void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
621 bool isroot,
int depth,
622 std::atomic<long long int>& flops)
const override;
623 void apply_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
624 WorkApply<scalar_t>& w,
bool isroot,
int depth,
625 std::atomic<long long int>& flops)
const override;
626 void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
627 int depth, std::atomic<long long int>& flops)
const override;
628 void applyT_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
629 WorkApply<scalar_t>& w,
bool isroot,
int depth,
630 std::atomic<long long int>& flops)
const override;
632 void solve_fwd(
const DenseM_t& b, WorkSolve<scalar_t>& w,
633 bool partial,
bool isroot,
int depth)
const override;
634 void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
635 bool isroot,
int depth)
const override;
637 void extract_fwd(WorkExtract<scalar_t>& w,
638 bool odiag,
int depth)
const override;
639 void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
640 int depth)
const override;
641 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
642 WorkExtract<scalar_t>& w,
int depth)
const override;
643 void extract_bwd_internal(WorkExtract<scalar_t>& w,
int depth)
const;
645 void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
647 const std::pair<std::size_t, std::size_t>& offset,
648 int depth, std::atomic<long long int>& flops)
650 void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
651 const std::pair<std::size_t, std::size_t>& offset,
652 int depth, std::atomic<long long int>& flops)
655 void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
656 bool isroot,
int depth)
const override;
661 template<
typename T>
friend
668 template<
typename T>
friend
671 void read(std::ifstream& is)
override;
672 void write(std::ofstream& os)
const override;
686 template<
typename scalar_t>
700 template<
typename scalar_t>
void
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representatio...
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:80
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:194
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
scalar_t get(std::size_t i, std::size_t j) const
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void write(const std::string &fname) const
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:204
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition: ClusterTree.hpp:67
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51