36 #ifndef HSS_MATRIX_MPI_HPP 
   37 #define HSS_MATRIX_MPI_HPP 
   43 #include "HSSExtraMPI.hpp" 
   44 #include "DistSamples.hpp" 
   45 #include "DistElemMult.hpp" 
   46 #include "HSSBasisIDMPI.hpp" 
   69       using real_t = 
typename RealType<scalar_t>::value_type;
 
   74       using delem_t = 
typename std::function
 
   75         <void(
const std::vector<std::size_t>& I,
 
   76               const std::vector<std::size_t>& J, 
DistM_t& B)>;
 
   77       using delem_blocks_t = 
typename std::function
 
   78         <void(
const std::vector<std::vector<std::size_t>>& I,
 
   79               const std::vector<std::vector<std::size_t>>& J,
 
   80               std::vector<DistMW_t>& B)>;
 
   81       using elem_t = 
typename std::function
 
   82         <void(
const std::vector<std::size_t>& I,
 
   83               const std::vector<std::size_t>& J, 
DenseM_t& B)>;
 
   84       using dmult_t = 
typename std::function
 
   96                    const dmult_t& Amult, 
const delem_t& Aelem,
 
   99                    const dmult_t& Amult, 
const delem_blocks_t& Aelem,
 
  102                    const dmult_t& Amult, 
const delem_t& Aelem,
 
  112       std::unique_ptr<HSSMatrixBase<scalar_t>> 
clone() 
const override;
 
  115         return this->ch_[c].get();
 
  // Accessor for the BLACS grid associated with this matrix node.
  // Returns the non-owning pointer stored in the blacs_grid_ member;
  // overrides the virtual of the same signature in the base class.
  119       const BLACSGrid* grid()
 const override { 
return blacs_grid_; }
 
  // Overload of grid() that accepts a grid pointer.
  // The argument is not read here: this override always returns the
  // blacs_grid_ member, exactly like the zero-argument overload.
  // NOTE(review): presumably other HSSMatrixBase subclasses return the
  // passed-in grid instead -- confirm against the base class.
  120       const BLACSGrid* grid(
const BLACSGrid* grid)
 const override { 
return blacs_grid_; }
 
  // Accessor for the local BLACS grid.
  // Returns the non-owning pointer stored in blacs_grid_local_.
  121       const BLACSGrid* grid_local()
 const override { 
return blacs_grid_local_; }
 
  // Returns the MPIComm wrapper of this matrix's BLACS grid by
  // forwarding to grid()->Comm().
  // grid() is dereferenced without a null check, so blacs_grid_ must be
  // non-null when this is called.
  122       const MPIComm& Comm()
 const { 
return grid()->
Comm(); }
 
  // Returns the raw MPI_Comm handle, obtained as Comm().comm().
  // Same precondition as Comm(): the grid pointer must be non-null.
  123       MPI_Comm comm()
 const { 
return Comm().
comm(); }
 
  // Returns grid()->P().
  // NOTE(review): presumably the total number of processes in the BLACS
  // grid -- confirm against BLACSGrid::P().
  // grid() is dereferenced without a null check.
  124       int Ptotal()
 const override { 
return grid()->
P(); }
 
  // Returns grid()->npactives().
  // NOTE(review): presumably the number of processes that are active in
  // the BLACS grid -- confirm against BLACSGrid::npactives().
  // grid() is dereferenced without a null check.
  125       int Pactive()
 const override { 
return grid()->
npactives(); }
 
  128       void compress(
const DistM_t& A,
 
  130       void compress(
const dmult_t& Amult,
 
  131                     const delem_t& Aelem,
 
  133       void compress(
const dmult_t& Amult,
 
  134                     const delem_blocks_t& Aelem,
 
  136       void compress(
const kernel::Kernel<real_t>& K, 
const opts_t& opts);
 
  139       void partial_factor();
 
  141       void forward_solve(WorkSolveMPI<scalar_t>& w, 
const DistM_t& b,
 
  142                          bool partial) 
const override;
 
  143       void backward_solve(WorkSolveMPI<scalar_t>& w,
 
  151       scalar_t get(std::size_t i, std::size_t j) 
const;
 
  152       DistM_t extract(
const std::vector<std::size_t>& I,
 
  153                       const std::vector<std::size_t>& J,
 
  156       extract(
const std::vector<std::vector<std::size_t>>& I,
 
  157               const std::vector<std::vector<std::size_t>>& J,
 
  159       void extract_add(
const std::vector<std::size_t>& I,
 
  160                        const std::vector<std::size_t>& J, 
DistM_t& B) 
const;
 
  161       void extract_add(
const std::vector<std::vector<std::size_t>>& I,
 
  162                        const std::vector<std::vector<std::size_t>>& J,
 
  163                        std::vector<DistM_t>& B) 
const;
 
  167       void Schur_product_direct(
const DistM_t& Theta,
 
  176       std::size_t max_rank() 
const;        
 
  177       std::size_t total_memory() 
const;    
 
  178       std::size_t total_nonzeros() 
const;  
 
  179       std::size_t total_factor_nonzeros() 
const;  
 
  180       std::size_t max_levels() 
const;      
 
  181       std::size_t 
rank() 
const override;
 
  184       std::size_t factor_nonzeros() 
const override;
 
  189                       std::size_t coff=0) 
const override;
 
  193       void shift(scalar_t sigma) 
override;
 
  196       void to_block_row(
const DistM_t& A,
 
  198                         DistM_t& leaf_A) 
const override;
 
  199       void allocate_block_row(
int d, DenseM_t& sub_A,
 
  200                               DistM_t& leaf_A) 
const override;
 
  201       void from_block_row(DistM_t& A,
 
  202                           const DenseM_t& sub_A,
 
  203                           const DistM_t& leaf_A,
 
  206       void delete_trailing_block() 
override;
 
  207       void reset() 
override;
 
  210       using delemw_t = 
typename std::function
 
  211         <void(
const std::vector<std::size_t>& I,
 
  212               const std::vector<std::size_t>& J,
 
  213               DistM_t& B, DistM_t& A,
 
  214               std::size_t rlo, std::size_t clo,
 
  219       std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
 
  220       std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
 
  224       HSSBasisIDMPI<scalar_t> U_, V_;
 
  225       DistM_t D_, B01_, B10_;
 
  229       DistM_t A_, A01_, A10_;
 
  231       HSSMatrixMPI(std::size_t m, std::size_t n, 
const opts_t& opts,
 
  233                    std::size_t roff, std::size_t coff);
 
  236                    std::size_t roff, std::size_t coff);
 
  237       void setup_hierarchy(
const opts_t& opts,
 
  238                            std::size_t roff, std::size_t coff);
 
  240                            std::size_t roff, std::size_t coff);
 
  241       void setup_local_context();
 
  242       void setup_ranges(std::size_t roff, std::size_t coff);
 
  244       void compress_original_nosync(
const dmult_t& Amult,
 
  245                                     const delemw_t& Aelem,
 
  247       void compress_original_sync(
const dmult_t& Amult,
 
  248                                   const delemw_t& Aelem,
 
  250       void compress_original_sync(
const dmult_t& Amult,
 
  251                                   const delem_blocks_t& Aelem,
 
  253       void compress_stable_nosync(
const dmult_t& Amult,
 
  254                                   const delemw_t& Aelem,
 
  256       void compress_stable_sync(
const dmult_t& Amult,
 
  257                                 const delemw_t& Aelem,
 
  259       void compress_stable_sync(
const dmult_t& Amult,
 
  260                                 const delem_blocks_t& Aelem,
 
  262       void compress_hard_restart_nosync(
const dmult_t& Amult,
 
  263                                         const delemw_t& Aelem,
 
  265       void compress_hard_restart_sync(
const dmult_t& Amult,
 
  266                                       const delemw_t& Aelem,
 
  268       void compress_hard_restart_sync(
const dmult_t& Amult,
 
  269                                       const delem_blocks_t& Aelem,
 
  274                                   const delemw_t& Aelem,
 
  275                                   WorkCompressMPIANN<scalar_t>& w,
 
  280                                      WorkCompressMPIANN<scalar_t>& w,
 
  281                                      const delemw_t& Aelem,
 
  283       bool compute_U_V_bases_ann(DistM_t& S, 
const opts_t& opts,
 
  284                                  WorkCompressMPIANN<scalar_t>& w);
 
  285       void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
 
  287       void compress_recursive_original(DistSamples<scalar_t>& RS,
 
  288                                        const delemw_t& Aelem,
 
  290                                        WorkCompressMPI<scalar_t>& w,
 
  292       void compress_recursive_stable(DistSamples<scalar_t>& RS,
 
  293                                      const delemw_t& Aelem,
 
  295                                      WorkCompressMPI<scalar_t>& w,
 
  296                                      int d, 
int dd) 
override;
 
  297       void compute_local_samples(
const DistSamples<scalar_t>& RS,
 
  298                                  WorkCompressMPI<scalar_t>& w, 
int dd);
 
  299       bool compute_U_V_bases(
int d, 
const opts_t& opts,
 
  300                              WorkCompressMPI<scalar_t>& w);
 
  301       void compute_U_basis_stable(
const opts_t& opts,
 
  302                                   WorkCompressMPI<scalar_t>& w,
 
  304       void compute_V_basis_stable(
const opts_t& opts,
 
  305                                   WorkCompressMPI<scalar_t>& w,
 
  307       bool update_orthogonal_basis(
const opts_t& opts,
 
  308                                    scalar_t& r_max_0, 
const DistM_t& S,
 
  309                                    DistM_t& Q, 
int d, 
int dd,
 
  310                                    bool untouched, 
int L);
 
  311       void reduce_local_samples(
const DistSamples<scalar_t>& RS,
 
  312                                 WorkCompressMPI<scalar_t>& w,
 
  313                                 int dd, 
bool was_compressed);
 
  314       void communicate_child_data(WorkCompressMPI<scalar_t>& w);
 
  315       void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
 
  316       void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
 
  317       void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
 
  319       void compress_level_original(DistSamples<scalar_t>& RS,
 
  321                                    WorkCompressMPI<scalar_t>& w,
 
  322                                    int dd, 
int lvl) 
override;
 
  323       void compress_level_stable(DistSamples<scalar_t>& RS,
 
  325                                  WorkCompressMPI<scalar_t>& w,
 
  326                                  int d, 
int dd, 
int lvl) 
override;
 
  327       void extract_level(
const delemw_t& Aelem, 
const opts_t& opts,
 
  328                          WorkCompressMPI<scalar_t>& w, 
int lvl);
 
  329       void extract_level(
const delem_blocks_t& Aelem, 
const opts_t& opts,
 
  330                          WorkCompressMPI<scalar_t>& w, 
int lvl);
 
  331       void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
 
  332                                   std::vector<std::vector<std::size_t>>& J,
 
  333                                   WorkCompressMPI<scalar_t>& w,
 
  334                                   int& 
self, 
int lvl) 
override;
 
  335       void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
 
  336                                   std::vector<std::vector<std::size_t>>& J,
 
  337                                   std::vector<DistMW_t>& B,
 
  339                                   WorkCompressMPI<scalar_t>& w,
 
  340                                   int& 
self, 
int lvl) 
override;
 
  341       void allgather_extraction_indices(std::vector<std::vector<std::size_t>>& lI,
 
  342                                         std::vector<std::vector<std::size_t>>& lJ,
 
  343                                         std::vector<std::vector<std::size_t>>& I,
 
  344                                         std::vector<std::vector<std::size_t>>& J,
 
  345                                         int& before, 
int self, 
int& after);
 
  346       void extract_D_B(
const delemw_t& Aelem,
 
  348                        WorkCompressMPI<scalar_t>& w, 
int lvl) 
override;
 
  350       void factor_recursive(WorkFactorMPI<scalar_t>& w,
 
  352                             bool isroot, 
bool partial) 
override;
 
  354       void solve_fwd(
const DistSubLeaf<scalar_t>& b,
 
  355                      WorkSolveMPI<scalar_t>& w,
 
  356                      bool partial, 
bool isroot) 
const override;
 
  357       void solve_bwd(DistSubLeaf<scalar_t>& x,
 
  358                      WorkSolveMPI<scalar_t>& w, 
bool isroot) 
const override;
 
  360       void apply_fwd(
const DistSubLeaf<scalar_t>& B,
 
  361                      WorkApplyMPI<scalar_t>& w,
 
  362                      bool isroot, 
long long int flops) 
const override;
 
  363       void apply_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
 
  364                      DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
 
  365                      bool isroot, 
long long int flops) 
const override;
 
  366       void applyT_fwd(
const DistSubLeaf<scalar_t>& B,
 
  367                       WorkApplyMPI<scalar_t>& w,
 
  368                       bool isroot, 
long long int flops) 
const override;
 
  369       void applyT_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
 
  370                       DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
 
  371                       bool isroot, 
long long int flops) 
const override;
 
  373       void extract_fwd(WorkExtractMPI<scalar_t>& w, 
const BLACSGrid* lg,
 
  374                        bool odiag) 
const override;
 
  375       void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
 
  377                        WorkExtractMPI<scalar_t>& w) 
const override;
 
  378       void triplets_to_DistM(std::vector<Triplet<scalar_t>>& triplets,
 
  380       void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
 
  382                        std::vector<bool>& odiag) 
const override;
 
  383       void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
 
  385                        WorkExtractBlocksMPI<scalar_t>& w) 
const override;
 
  386       void triplets_to_DistM(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
 
  387                              std::vector<DistM_t>& B) 
const;
 
  389       void redistribute_to_tree_to_buffers(
const DistM_t& A,
 
  390                                            std::size_t Arlo, std::size_t Aclo,
 
  391                                            std::vector<std::vector<scalar_t>>& sbuf,
 
  392                                            int dest=0) 
override;
 
  393       void redistribute_to_tree_from_buffers(
const DistM_t& A,
 
  394                                              std::size_t rlo, std::size_t clo,
 
  395                                              std::vector<scalar_t*>& pbuf)
 
  397       void delete_redistributed_input() 
override;
 
  399       void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
 
  400                         DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
 
  401                         long long int& flops) 
const override;
 
  403       static int Pl(std::size_t n, std::size_t nl, std::size_t nr, 
int P) {
 
  405           (1, std::min(
int(std::round(
float(P) * nl / n)), P-1));
 
  407       static int Pr(std::size_t n, std::size_t nl, std::size_t nr, 
int P) {
 
  408         return std::max(1, P - Pl(n, nl, nr, P));
 
  411         return Pl(this->
rows(), child(0)->
rows(),
 
  412                   child(1)->
rows(), Ptotal());
 
  415         return Pr(this->
rows(), child(0)->
rows(),
 
  416                   child(1)->
rows(), Ptotal());
 
  419       template<
typename T> 
friend 
  420       void apply_HSS(
Trans ta, 
const HSSMatrixMPI<T>& a,
 
  421                      const DistributedMatrix<T>& b, T beta,
 
  422                      DistributedMatrix<T>& c);
 
  423       friend class DistSamples<scalar_t>;
 
This file contains the HSSMatrix class definition as well as implementations for a number of its mem...
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Contains some simple C++ MPI wrapper utilities.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition: BLACSGrid.hpp:66
int npactives() const
Definition: BLACSGrid.hpp:257
int P() const
Definition: BLACSGrid.hpp:251
const MPIComm & Comm() const
Definition: BLACSGrid.hpp:196
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Definition: DistributedMatrix.hpp:733
2D block-cyclically distributed matrix, as used by ScaLAPACK.
Definition: DistributedMatrix.hpp:84
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
const HSSMatrixBase< scalar_t > & child(int c) const
Definition: HSSMatrixBase.hpp:188
std::size_t rows() const override
Definition: HSSMatrixBase.hpp:163
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
void mult(Trans op, const DistM_t &x, DistM_t &y) const override
void shift(scalar_t sigma) override
std::size_t nonzeros() const override
std::size_t levels() const override
void solve(DistM_t &b) const override
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
std::size_t rank() const override
std::size_t memory() const override
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Definition: HSSExtraMPI.hpp:134
Wrapper class around an MPI_Comm object.
Definition: MPIWrapper.hpp:194
MPI_Comm comm() const
Definition: MPIWrapper.hpp:261
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition: ClusterTree.hpp:67
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51