34 #ifndef HSS_MATRIX_BASE_HPP
35 #define HSS_MATRIX_BASE_HPP
45 #include "misc/Triplet.hpp"
47 #include "HSSExtra.hpp"
49 #if defined(STRUMPACK_USE_MPI)
51 #include "HSSExtraMPI.hpp"
58 #ifndef DOXYGEN_SHOULD_SKIP_THIS
59 template<
typename scalar_t>
class HSSMatrix;
60 #if defined(STRUMPACK_USE_MPI)
61 template<
typename scalar_t>
class HSSMatrixMPI;
62 template<
typename scalar_t>
class DistSubLeaf;
63 template<
typename scalar_t>
class DistSamples;
84 using real_t =
typename RealType<scalar_t>::value_type;
87 using elem_t =
typename std::function
88 <void(
const std::vector<std::size_t>& I,
89 const std::vector<std::size_t>& J,
DenseM_t& B)>;
91 #if defined(STRUMPACK_USE_MPI)
94 using delem_t =
typename std::function
95 <void(
const std::vector<std::size_t>& I,
96 const std::vector<std::size_t>& J,
DistM_t& B)>;
147 virtual std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const = 0;
155 std::pair<std::size_t,std::size_t>
dims()
const {
156 return std::make_pair(rows_, cols_);
163 std::size_t
rows()
const override {
return rows_; }
169 std::size_t
cols()
const override {
return cols_; }
175 bool leaf()
const {
return ch_.empty(); }
177 virtual std::size_t factor_nonzeros()
const;
189 assert(c>=0 && c<
int(ch_.size()));
return *(ch_[c]);
202 assert(c>=0 && c<
int(ch_.size()));
return *(ch_[c]);
262 std::size_t coff=0)
const = 0;
275 #ifndef DOXYGEN_SHOULD_SKIP_THIS
276 virtual void delete_trailing_block() {
if (ch_.size()==2) ch_.resize(1); }
277 virtual void reset() {
279 U_rank_ = U_rows_ = V_rank_ = V_rows_ = 0;
280 for (
auto& c : ch_) c->reset();
291 virtual void shift(scalar_t sigma)
override = 0;
299 virtual void draw(std::ostream& of,
301 std::size_t clo)
const {}
303 #if defined(STRUMPACK_USE_MPI)
304 virtual void forward_solve(WorkSolveMPI<scalar_t>& w,
305 const DistM_t& b,
bool partial)
const;
306 virtual void backward_solve(WorkSolveMPI<scalar_t>& w,
309 virtual const BLACSGrid* grid()
const {
return nullptr; }
311 return active() ? local_grid :
nullptr;
313 virtual const BLACSGrid* grid_local()
const {
return nullptr; }
314 virtual int Ptotal()
const {
return 1; }
315 virtual int Pactive()
const {
return 1; }
317 virtual void to_block_row(
const DistM_t& A, DenseM_t& sub_A,
318 DistM_t& leaf_A)
const;
319 virtual void allocate_block_row(
int d, DenseM_t& sub_A,
320 DistM_t& leaf_A)
const;
321 virtual void from_block_row(DistM_t& A,
const DenseM_t& sub_A,
322 const DistM_t& leaf_A,
323 const BLACSGrid* lg)
const;
327 std::size_t rows_, cols_;
330 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> ch_;
331 State U_state_, V_state_;
332 int openmp_task_depth_;
335 int U_rank_ = 0, U_rows_ = 0, V_rank_ = 0, V_rows_ = 0;
341 HSSFactors<scalar_t> ULV_;
342 #if defined(STRUMPACK_USE_MPI)
343 HSSFactorsMPI<scalar_t> ULV_mpi_;
346 virtual std::size_t U_rank()
const {
return U_rank_; }
347 virtual std::size_t V_rank()
const {
return V_rank_; }
348 virtual std::size_t U_rows()
const {
return U_rows_; }
349 virtual std::size_t V_rows()
const {
return V_rows_; }
352 compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
353 DenseM_t& Sr, DenseM_t& Sc,
354 const elem_t& Aelem,
const opts_t& opts,
355 WorkCompress<scalar_t>& w,
356 int dd,
int depth) {}
358 compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
359 DenseM_t& Sr, DenseM_t& Sc,
360 const elem_t& Aelem,
const opts_t& opts,
361 WorkCompress<scalar_t>& w,
362 int d,
int dd,
int depth) {}
364 compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
365 DenseM_t& Sr, DenseM_t& Sc,
366 const opts_t& opts, WorkCompress<scalar_t>& w,
367 int dd,
int lvl,
int depth) {}
369 compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
370 DenseM_t& Sr, DenseM_t& Sc,
371 const opts_t& opts, WorkCompress<scalar_t>& w,
372 int d,
int dd,
int lvl,
int depth) {}
374 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
375 DenseMatrix<real_t>& scores,
376 const elem_t& Aelem,
const opts_t& opts,
377 WorkCompressANN<scalar_t>& w,
int depth) {}
380 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
381 std::vector<std::vector<std::size_t>>& J,
382 const std::pair<std::size_t,std::size_t>& off,
383 WorkCompress<scalar_t>& w,
384 int&
self,
int lvl) {}
387 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
388 std::vector<std::vector<std::size_t>>& J,
389 std::vector<DenseM_t*>& B,
390 const std::pair<std::size_t,std::size_t>& off,
391 WorkCompress<scalar_t>& w,
392 int&
self,
int lvl) {}
393 virtual void extract_D_B(
const elem_t& Aelem,
const opts_t& opts,
394 WorkCompress<scalar_t>& w,
int lvl) {}
396 virtual void factor_recursive(WorkFactor<scalar_t>& w,
397 bool isroot,
bool partial,
400 virtual void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
401 bool isroot,
int depth,
402 std::atomic<long long int>& flops)
const {}
403 virtual void apply_bwd(
const DenseM_t& b, scalar_t beta,
404 DenseM_t& c, WorkApply<scalar_t>& w,
405 bool isroot,
int depth,
406 std::atomic<long long int>& flops)
const {}
407 virtual void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
408 bool isroot,
int depth,
409 std::atomic<long long int>& flops)
const {}
410 virtual void applyT_bwd(
const DenseM_t& b, scalar_t beta,
411 DenseM_t& c, WorkApply<scalar_t>& w,
412 bool isroot,
int depth,
413 std::atomic<long long int>& flops)
const {}
415 virtual void forward_solve(WorkSolve<scalar_t>& w,
416 const DenseMatrix<scalar_t>& b,
417 bool partial)
const {}
418 virtual void backward_solve(WorkSolve<scalar_t>& w,
419 DenseMatrix<scalar_t>& b)
const {}
420 virtual void solve_fwd(
const DenseM_t& b,
421 WorkSolve<scalar_t>& w,
bool partial,
422 bool isroot,
int depth)
const {}
423 virtual void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
424 bool isroot,
int depth)
const {}
426 virtual void extract_fwd(WorkExtract<scalar_t>& w,
427 bool odiag,
int depth)
const {}
428 virtual void extract_bwd(DenseMatrix<scalar_t>& B,
429 WorkExtract<scalar_t>& w,
431 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
432 WorkExtract<scalar_t>& w,
int depth)
const {}
434 virtual void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop,
435 DenseM_t& Phi, DenseM_t& Vop,
436 const std::pair<std::size_t,std::size_t>& offset,
438 std::atomic<long long int>& flops)
const {}
439 virtual void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA,
441 const std::pair<std::size_t, std::size_t>& offset,
443 std::atomic<long long int>& flops)
const {}
445 virtual void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
446 bool isroot,
int depth)
const {}
448 virtual void read(std::ifstream& os) {
449 std::cerr <<
"ERROR read_HSS_node not implemented" << std::endl;
451 virtual void write(std::ofstream& os)
const {
452 std::cerr <<
"ERROR write_HSS_node not implemented" << std::endl;
455 friend class HSSMatrix<scalar_t>;
457 #if defined(STRUMPACK_USE_MPI)
458 using delemw_t =
typename std::function
459 <void(
const std::vector<std::size_t>& I,
460 const std::vector<std::size_t>& J,
461 DistM_t& B, DistM_t& A,
462 std::size_t rlo, std::size_t clo,
467 compress_recursive_original(DistSamples<scalar_t>& RS,
468 const delemw_t& Aelem,
470 WorkCompressMPI<scalar_t>& w,
int dd);
472 compress_recursive_stable(DistSamples<scalar_t>& RS,
473 const delemw_t& Aelem,
475 WorkCompressMPI<scalar_t>& w,
int d,
int dd);
477 compress_level_original(DistSamples<scalar_t>& RS,
const opts_t& opts,
478 WorkCompressMPI<scalar_t>& w,
int dd,
int lvl);
480 compress_level_stable(DistSamples<scalar_t>& RS,
const opts_t& opts,
481 WorkCompressMPI<scalar_t>& w,
482 int d,
int dd,
int lvl);
484 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
485 DenseMatrix<real_t>& scores,
486 const delemw_t& Aelem,
487 WorkCompressMPIANN<scalar_t>& w,
488 const opts_t& opts,
const BLACSGrid* lg);
491 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
492 std::vector<std::vector<std::size_t>>& J,
493 WorkCompressMPI<scalar_t>& w,
496 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
497 std::vector<std::vector<std::size_t>>& J,
498 std::vector<DistMW_t>& B,
500 WorkCompressMPI<scalar_t>& w,
502 virtual void extract_D_B(
const delemw_t& Aelem,
const BLACSGrid* lg,
504 WorkCompressMPI<scalar_t>& w,
int lvl);
506 virtual void apply_fwd(
const DistSubLeaf<scalar_t>& B,
507 WorkApplyMPI<scalar_t>& w,
508 bool isroot,
long long int flops)
const;
509 virtual void apply_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
510 DistSubLeaf<scalar_t>& C,
511 WorkApplyMPI<scalar_t>& w,
512 bool isroot,
long long int flops)
const;
513 virtual void applyT_fwd(
const DistSubLeaf<scalar_t>& B,
514 WorkApplyMPI<scalar_t>& w,
515 bool isroot,
long long int flops)
const;
516 virtual void applyT_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
517 DistSubLeaf<scalar_t>& C,
518 WorkApplyMPI<scalar_t>& w,
519 bool isroot,
long long int flops)
const;
521 virtual void factor_recursive(WorkFactorMPI<scalar_t>& w,
522 const BLACSGrid* lg,
bool isroot,
525 virtual void solve_fwd(
const DistSubLeaf<scalar_t>& b,
526 WorkSolveMPI<scalar_t>& w,
527 bool partial,
bool isroot)
const;
528 virtual void solve_bwd(DistSubLeaf<scalar_t>& x,
529 WorkSolveMPI<scalar_t>& w,
bool isroot)
const;
531 virtual void extract_fwd(WorkExtractMPI<scalar_t>& w,
532 const BLACSGrid* lg,
bool odiag)
const;
533 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
535 WorkExtractMPI<scalar_t>& w)
const;
536 virtual void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
538 std::vector<bool>& odiag)
const;
539 virtual void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
541 WorkExtractBlocksMPI<scalar_t>& w)
const;
543 virtual void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
544 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
545 long long int& flops)
const;
548 redistribute_to_tree_to_buffers(
const DistM_t& A,
549 std::size_t Arlo, std::size_t Aclo,
550 std::vector<std::vector<scalar_t>>& sbuf,
553 redistribute_to_tree_from_buffers(
const DistM_t& A,
554 std::size_t Arlo, std::size_t Aclo,
555 std::vector<scalar_t*>& pbuf);
556 virtual void delete_redistributed_input();
558 friend class HSSMatrixMPI<scalar_t>;
Contains the DenseMatrix and DenseMatrixWrapper classes, simple wrappers around BLAS/LAPACK style den...
Contains the DistributedMatrix and DistributedMatrixWrapper classes, wrappers around ScaLAPACK/PBLAS ...
This file contains the HSSMatrixMPI class definition as well as implementations for a number of it's ...
Contains the HSSOptions class as well as general routines for HSS options.
Contains the structured matrix interfaces.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition: BLACSGrid.hpp:66
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Definition: DistributedMatrix.hpp:733
2D block cyclicly distributed matrix, as used by ScaLAPACK.
Definition: DistributedMatrix.hpp:84
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
void set_openmp_task_depth(int depth)
Definition: HSSMatrixBase.hpp:273
bool active() const
Definition: HSSMatrixBase.hpp:239
HSSMatrixBase(HSSMatrixBase &&h)=default
HSSMatrixBase(const HSSMatrixBase< scalar_t > &other)
virtual std::size_t levels() const =0
virtual void draw(std::ostream &of, std::size_t rlo, std::size_t clo) const
Definition: HSSMatrixBase.hpp:299
const HSSMatrixBase< scalar_t > & child(int c) const
Definition: HSSMatrixBase.hpp:188
std::size_t rows() const override
Definition: HSSMatrixBase.hpp:163
std::pair< std::size_t, std::size_t > dims() const
Definition: HSSMatrixBase.hpp:155
HSSMatrixBase< scalar_t > & operator=(const HSSMatrixBase< scalar_t > &other)
virtual void shift(scalar_t sigma) override=0
HSSMatrixBase(std::size_t m, std::size_t n, bool active)
virtual ~HSSMatrixBase()=default
bool is_compressed() const
Definition: HSSMatrixBase.hpp:213
virtual void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const =0
bool leaf() const
Definition: HSSMatrixBase.hpp:175
HSSMatrixBase & operator=(HSSMatrixBase &&h)=default
virtual std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const =0
HSSMatrixBase< scalar_t > & child(int c)
Definition: HSSMatrixBase.hpp:201
std::size_t cols() const override
Definition: HSSMatrixBase.hpp:169
bool is_untouched() const
Definition: HSSMatrixBase.hpp:228
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Class to represent a structured matrix. This is the abstract base class for several types of structur...
Definition: StructuredMatrix.hpp:209
State
Definition: HSSExtra.hpp:46
Definition: StrumpackOptions.hpp:43