37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
44 #include "HSSBasisID.hpp"
46 #include "HSSExtra.hpp"
// NOTE(review): the leading integers on these lines ("77", "80", "81", ...)
// look like line numbers left over from a rendered source listing; they are
// not valid C++ and must be stripped (and the missing lines restored) before
// this header can compile.
// real_t: alias for RealType<scalar_t>::value_type. Presumably the real type
// underlying a possibly complex scalar_t -- confirm against RealType.
77 using real_t =
typename RealType<scalar_t>::value_type;
// elem_t: std::function callback taking a row-index list I, a column-index
// list J (both by const reference) and an output matrix B (by non-const
// reference). NOTE(review): presumably it fills B with the sub-block A(I,J)
// of the matrix being compressed -- confirm against the compress routines.
80 using elem_t =
typename std::function
81 <void(
const std::vector<std::size_t>& I,
82 const std::vector<std::size_t>& J,
DenseM_t& B)>;
83 using mult_t =
typename std::function
// clone: const member returning an owning std::unique_ptr<HSSMatrixBase<scalar_t>>;
// overrides a virtual declared in a base class not visible here.
// NOTE(review): presumably returns a copy of this matrix (as the name
// suggests) -- confirm deep vs. shallow semantics in the implementation.
184 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
// compress (dense-input overload): non-const, so it mutates this object;
// reads the dense matrix A and the options opts by const reference.
// NOTE(review): presumably builds the HSS representation of A according to
// opts -- confirm which algorithm variant opts selects.
222 void compress(
const DenseM_t& A,
const opts_t& opts);
270 <
void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>& Amult,
272 <
void(
const std::vector<std::size_t>& I,
273 const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
306 <
void(
const std::vector<std::size_t>& I,
307 const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
// reset: non-const override of a base-class virtual; takes no arguments.
// NOTE(review): presumably discards the current compressed state so the
// matrix can be recompressed -- confirm in the implementation.
314 void reset()
override;
367 const DenseM_t& b,
bool partial)
const override;
385 WorkSolve<scalar_t>& w, DenseM_t& x)
const override;
// apply: const member taking b by const reference and returning a new
// DenseM_t by value; neither this object nor b is modified.
// NOTE(review): presumably computes the product H * b -- confirm.
395 DenseM_t
apply(
const DenseM_t& b)
const;
// applyC: const member taking b by const reference and returning a new
// DenseM_t by value; same signature shape as apply.
// NOTE(review): the "C" suffix suggests the (conjugate) transpose product
// H^C * b -- confirm in the implementation.
406 DenseM_t
applyC(
const DenseM_t& b)
const;
// get: const member returning a single scalar_t for indices (i, j).
// NOTE(review): presumably the entry H(i, j) of the represented matrix;
// bounds checking is not visible here -- confirm.
418 scalar_t
get(std::size_t i, std::size_t j)
const;
431 (
const std::vector<std::size_t>& I,
432 const std::vector<std::size_t>& J)
const;
448 (
const std::vector<std::size_t>& I,
const std::vector<std::size_t>& J,
451 #ifndef DOXYGEN_SHOULD_SKIP_THIS
454 DenseM_t& DUB01, DenseM_t& Phi)
const;
455 void Schur_product_direct
457 const DenseM_t& Theta,
const DenseM_t& DUB01,
458 const DenseM_t& Phi,
const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
459 const DenseM_t& R, DenseM_t& Sr, DenseM_t& Sc)
const;
460 void Schur_product_indirect
462 const DenseM_t& R1,
const DenseM_t& R2,
const DenseM_t& Sr2,
463 const DenseM_t& Sc2, DenseM_t& Sr, DenseM_t& Sc)
const;
// delete_trailing_block: non-const override of a base-class virtual.
// Declared inside the #ifndef DOXYGEN_SHOULD_SKIP_THIS region, so it is
// excluded from generated documentation.
// NOTE(review): exact effect (which block is dropped) is not visible here.
464 void delete_trailing_block()
override;
465 #endif // DOXYGEN_SHOULD_SKIP_THIS
// rank: const override returning a std::size_t.
// NOTE(review): presumably the maximum off-diagonal rank over the tree --
// confirm (max vs. root vs. sum) in the implementation.
467 std::size_t
rank()
const override;
// memory: const override returning a std::size_t.
// NOTE(review): presumably the storage footprint of the representation;
// units (bytes vs. entries) are not visible here -- confirm.
468 std::size_t
memory()
const override;
// nonzeros: const override returning a std::size_t.
// NOTE(review): presumably the number of stored scalar entries -- confirm.
469 std::size_t
nonzeros()
const override;
// levels: const override returning a std::size_t.
// NOTE(review): presumably the number of levels in the HSS tree -- confirm
// whether a single leaf counts as 1 level.
470 std::size_t
levels()
const override;
472 (std::ostream& out=std::cout,
473 std::size_t roff=0, std::size_t coff=0)
const override;
// dense: const member returning a DenseM_t by value.
// NOTE(review): presumably expands the HSS representation into a full dense
// matrix; cost is not visible here but likely grows with the full size.
481 DenseM_t
dense()
const;
// shift: non-const override of a base-class virtual taking a scalar sigma.
// NOTE(review): presumably adds sigma to the diagonal (H + sigma*I) --
// confirm in the implementation.
483 void shift(scalar_t sigma)
override;
486 (std::ostream& of, std::size_t rlo=0, std::size_t clo=0)
const override;
490 (std::size_t m, std::size_t n,
const opts_t& opts,
bool active);
// Data members of type HSSBasisID<scalar_t> (declared in HSSBasisID.hpp).
// NOTE(review): presumably the row (_U) and column (_V) interpolative bases
// of this node -- confirm.
493 HSSBasisID<scalar_t> _U;
494 HSSBasisID<scalar_t> _V;
// compress_original: two non-const overloads.
// The first reads a dense matrix A; the second reads the matrix only through
// callbacks: Amult (mult_t) and Aelem (elem_t).
// NOTE(review): presumably the "original" compression variant, in contrast to
// the _stable and _hard_restart variants declared below -- confirm.
499 void compress_original(
const DenseM_t& A,
const opts_t& opts);
500 void compress_original
501 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
// compress_stable (dense-input overload): non-const; reads A and opts by
// const reference.
// NOTE(review): presumably a numerically more stable compression variant --
// confirm what distinguishes it from compress_original.
502 void compress_stable(
const DenseM_t& A,
const opts_t& opts);
504 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
// compress_hard_restart: two non-const overloads.
// The first reads a dense matrix A; the second reads the matrix through the
// callbacks Amult and Aelem.
// NOTE(review): presumably restarts compression from scratch when adaptive
// sampling needs more samples -- confirm.
505 void compress_hard_restart(
const DenseM_t& A,
const opts_t& opts);
506 void compress_hard_restart
507 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
// compress_recursive_original: non-const override of a base-class virtual.
// Rr, Rc, Sr and Sc are passed by non-const reference, so the callee may
// modify them. Entries come through Aelem. Workspace is w (WorkCompress).
// NOTE(review): presumably Rr/Rc are random and Sr/Sc sample matrices;
// dd and depth are presumably sample count and recursion depth -- confirm.
509 void compress_recursive_original
510 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
511 const elem_t& Aelem,
const opts_t& opts,
512 WorkCompress<scalar_t>& w,
int dd,
int depth)
override;
// compress_recursive_stable: non-const override of a base-class virtual.
// Same parameter shape as compress_recursive_original plus an extra int d.
// NOTE(review): presumably d is the total sample count and dd the increment
// for this round -- confirm.
513 void compress_recursive_stable
514 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
515 const elem_t& Aelem,
const opts_t& opts,
516 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth)
override;
// compute_local_samples: non-virtual helper (no override); non-const, and all
// four matrices plus w are passed by non-const reference.
// NOTE(review): presumably updates samples for columns d0..d of this node --
// confirm the meaning of d0 and d.
517 void compute_local_samples
518 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
519 WorkCompress<scalar_t>& w,
int d0,
int d,
int depth);
// compute_U_V_bases: non-virtual helper returning bool; Sr, Sc and w are
// passed by non-const reference.
// NOTE(review): the meaning of the bool result (e.g. "bases are accurate
// enough") is not visible here -- confirm.
520 bool compute_U_V_bases
521 (DenseM_t& Sr, DenseM_t& Sc,
const opts_t& opts,
522 WorkCompress<scalar_t>& w,
int d,
int depth);
// compute_U_basis_stable: non-virtual helper of the "stable" path; takes
// the row samples Sr and the workspace w by non-const reference.
// NOTE(review): presumably computes the row basis _U -- confirm.
523 void compute_U_basis_stable
524 (DenseM_t& Sr,
const opts_t& opts,
525 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth);
// compute_V_basis_stable: column counterpart of compute_U_basis_stable; takes
// the column samples Sc and the workspace w by non-const reference.
// NOTE(review): presumably computes the column basis _V -- confirm.
526 void compute_V_basis_stable
527 (DenseM_t& Sc,
const opts_t& opts,
528 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth);
// reduce_local_samples: non-virtual helper; Rr, Rc and w are passed by
// non-const reference, and d0/d mirror compute_local_samples.
// NOTE(review): presumably projects random vectors onto the reduced bases
// for use by the parent node -- confirm.
529 void reduce_local_samples
530 (DenseM_t& Rr, DenseM_t& Rc, WorkCompress<scalar_t>& w,
531 int d0,
int d,
int depth);
// update_orthogonal_basis: non-virtual helper returning bool.
// r_max_0 and Q are passed by non-const reference, so the callee may update
// them; S is read-only.
// NOTE(review): presumably extends the orthonormal basis Q with new samples S,
// using r_max_0 as a reference magnitude for the stopping test and returning
// whether the tolerance is met -- confirm; the roles of "untouched" and L are
// not visible here.
532 bool update_orthogonal_basis
533 (
const opts_t& opts, scalar_t& r_max_0,
534 const DenseM_t& S, DenseM_t& Q,
int d,
int dd,
535 bool untouched,
int L,
int depth);
// set_U_full_rank / set_V_full_rank: non-virtual helpers taking only the
// workspace w by non-const reference.
// NOTE(review): presumably set _U / _V to identity-like full-rank bases when
// no compression is possible -- confirm.
536 void set_U_full_rank(WorkCompress<scalar_t>& w);
537 void set_V_full_rank(WorkCompress<scalar_t>& w);
// compress_level_original: non-const override of a base-class virtual.
// It has the same sample/workspace parameters as compress_recursive_original,
// plus a level index lvl, and no Aelem.
// NOTE(review): presumably a level-by-level (non-recursive) traversal of the
// original algorithm -- confirm.
539 void compress_level_original
540 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
541 const opts_t& opts, WorkCompress<scalar_t>& w,
542 int dd,
int lvl,
int depth)
override;
// compress_level_stable: non-const override of a base-class virtual.
// It is the level-indexed counterpart of compress_recursive_stable (extra
// int d and lvl).
// NOTE(review): semantics of d/dd/lvl inferred by analogy only -- confirm.
543 void compress_level_stable
544 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
545 const opts_t& opts, WorkCompress<scalar_t>& w,
546 int d,
int dd,
int lvl,
int depth)
override;
// get_extraction_indices: two non-const overrides of base-class virtuals.
// I and J (vectors of index vectors) are non-const references, i.e. outputs
// the callee appends to or fills. The second overload also collects a vector
// of non-owning DenseM_t* in B. self is passed as int&, so the callee may
// advance it.
// NOTE(review): off is presumably the (row, col) offset of this node in the
// global matrix, and self a running node counter -- confirm.
547 void get_extraction_indices
548 (std::vector<std::vector<std::size_t>>& I,
549 std::vector<std::vector<std::size_t>>& J,
550 const std::pair<std::size_t,std::size_t>& off,
551 WorkCompress<scalar_t>& w,
int&
self,
int lvl)
override;
// NOTE(review): pointers stored in B must not outlive the matrices they
// point to -- confirm ownership at the call site.
552 void get_extraction_indices
553 (std::vector<std::vector<std::size_t>>& I,
554 std::vector<std::vector<std::size_t>>& J, std::vector<DenseM_t*>& B,
555 const std::pair<std::size_t,std::size_t>& off,
556 WorkCompress<scalar_t>& w,
int&
self,
int lvl)
override;
558 (
const elem_t& Aelem,
const opts_t& opts,
559 WorkCompress<scalar_t>& w,
int lvl)
override;
563 void compress_recursive_ann
565 const elem_t& Aelem,
const opts_t& opts,
566 WorkCompressANN<scalar_t>& w,
int depth)
override;
567 void compute_local_samples_ann
569 WorkCompressANN<scalar_t>& w,
const elem_t& Aelem,
const opts_t& opts);
// compute_U_V_bases_ann: non-virtual helper returning bool. It takes a single
// sample matrix S and a WorkCompressANN workspace, both by non-const
// reference.
// NOTE(review): "ann" presumably refers to an approximate-nearest-neighbor
// based compression path; the meaning of the bool result is not visible
// here -- confirm.
570 bool compute_U_V_bases_ann
571 (DenseM_t& S,
const opts_t& opts,
572 WorkCompressANN<scalar_t>& w,
int depth);
574 void factor_recursive
576 bool isroot,
bool partial,
int depth)
const override;
579 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
580 int depth, std::atomic<long long int>& flops)
const override;
582 (
const DenseM_t& b, scalar_t beta, DenseM_t& c,
583 WorkApply<scalar_t>& w,
bool isroot,
int depth,
584 std::atomic<long long int>& flops)
const override;
586 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
587 int depth, std::atomic<long long int>& flops)
const override;
589 (
const DenseM_t& b, scalar_t beta, DenseM_t& c,
590 WorkApply<scalar_t>& w,
bool isroot,
int depth,
591 std::atomic<long long int>& flops)
const override;
595 WorkSolve<scalar_t>& w,
596 bool partial,
bool isroot,
int depth)
const override;
599 WorkSolve<scalar_t>& w,
600 bool isroot,
int depth)
const override;
603 (WorkExtract<scalar_t>& w,
bool odiag,
int depth)
const override;
605 (DenseM_t& B, WorkExtract<scalar_t>& w,
int depth)
const override;
607 (std::vector<Triplet<scalar_t>>& triplets,
608 WorkExtract<scalar_t>& w,
int depth)
const override;
// extract_bwd_internal: const, non-virtual helper; the matrix is not
// modified but the WorkExtract workspace w is (non-const reference).
// NOTE(review): presumably the backward pass of element extraction -- confirm.
609 void extract_bwd_internal(WorkExtract<scalar_t>& w,
int depth)
const;
612 (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
613 DenseM_t& Vop,
const std::pair<std::size_t, std::size_t>& offset,
614 int depth, std::atomic<long long int>& flops)
const override;
616 (
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
617 const std::pair<std::size_t, std::size_t>& offset,
618 int depth, std::atomic<long long int>& flops)
const override;
621 (DenseM_t& A, WorkDense<scalar_t>& w,
622 bool isroot,
int depth)
const override;
627 template<
typename T>
friend void apply_HSS
634 template<
typename T>
friend void draw
// draw: free function template over scalar_t taking an HSSMatrix by const
// reference and a name string.
// NOTE(review): presumably writes a visualization of H's block structure to
// an output identified by name -- confirm output format and file naming.
647 template<
typename scalar_t>
648 void draw(
const HSSMatrix<scalar_t>& H,
const std::string& name);
661 template<
typename scalar_t>
void apply_HSS
668 #endif // HSS_MATRIX_HPP