37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
45 #include "HSSBasisID.hpp"
47 #include "HSSExtra.hpp"
/// Real-valued type associated with scalar_t, taken from
/// RealType<scalar_t>::value_type.
/// NOTE(review): presumably the underlying real type when scalar_t is complex
/// (and scalar_t itself otherwise) -- confirm against RealType's definition.
78 using real_t =
typename RealType<scalar_t>::value_type;
/// Element-extraction callback type. The callback receives a row index list I,
/// a column index list J (both read-only), and an output matrix B passed by
/// non-const reference. The same signature is spelled out inline for the Aelem
/// parameters of the public compress overloads.
/// NOTE(review): presumably the callback must fill B with the entries A(I,J)
/// of the matrix being compressed -- confirm against callers.
81 using elem_t =
typename std::function
82 <void(
const std::vector<std::size_t>& I,
83 const std::vector<std::size_t>& J,
DenseM_t& B)>;
84 using mult_t =
typename std::function
/// Overrides the base-class clone(). Ownership of the result is returned
/// through a std::unique_ptr to the HSSMatrixBase<scalar_t> base type, so the
/// caller owns the new object.
/// NOTE(review): presumably a deep copy of this matrix -- confirm in the
/// definition.
183 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
/// Public compression entry point taking a dense matrix A and options opts.
/// NOTE(review): presumably builds this object's HSS representation from A.
/// Private compress_original / compress_stable / compress_hard_restart
/// overloads with the same (A, opts) parameters are declared in this class;
/// which one is used (presumably chosen via opts) must be confirmed in the
/// definition.
221 void compress(
const DenseM_t& A,
const opts_t& opts);
269 <
void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>& Amult,
271 <
void(
const std::vector<std::size_t>& I,
272 const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
305 <
void(
const std::vector<std::size_t>& I,
306 const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
/// Overrides the base-class reset(); takes no arguments and returns nothing.
/// NOTE(review): presumably clears this object's compressed state -- confirm
/// in the definition.
313 void reset()
override;
366 const DenseM_t& b,
bool partial)
const override;
384 WorkSolve<scalar_t>& w, DenseM_t& x)
const override;
/// Const member: takes a dense matrix b and returns a new DenseM_t by value.
/// NOTE(review): presumably computes the product H * b -- confirm.
394 DenseM_t
apply(
const DenseM_t& b)
const;
/// Const member with the same shape as apply(): takes a dense matrix b and
/// returns a new DenseM_t by value.
/// NOTE(review): the "C" suffix presumably denotes the conjugate-transposed
/// product H^C * b -- confirm.
405 DenseM_t
applyC(
const DenseM_t& b)
const;
/// Const member returning a single scalar_t for indices (i, j).
/// NOTE(review): presumably the (i, j) entry of the represented matrix;
/// valid index ranges are not visible here -- confirm.
417 scalar_t
get(std::size_t i, std::size_t j)
const;
430 (
const std::vector<std::size_t>& I,
431 const std::vector<std::size_t>& J)
const;
447 (
const std::vector<std::size_t>& I,
const std::vector<std::size_t>& J,
450 #ifndef DOXYGEN_SHOULD_SKIP_THIS
453 DenseM_t& DUB01, DenseM_t& Phi)
const;
454 void Schur_product_direct
456 const DenseM_t& Theta,
const DenseM_t& DUB01,
457 const DenseM_t& Phi,
const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
458 const DenseM_t& R, DenseM_t& Sr, DenseM_t& Sc)
const;
459 void Schur_product_indirect
461 const DenseM_t& R1,
const DenseM_t& R2,
const DenseM_t& Sr2,
462 const DenseM_t& Sc2, DenseM_t& Sr, DenseM_t& Sc)
const;
/// Overrides the base-class delete_trailing_block(). It is declared inside the
/// DOXYGEN_SHOULD_SKIP_THIS guard, so it is hidden from generated docs.
/// NOTE(review): which block is removed is not visible here -- confirm in
/// the definition.
463 void delete_trailing_block()
override;
464 #endif // DOXYGEN_SHOULD_SKIP_THIS
/// Const overrides of base-class size/statistics queries; each returns a
/// std::size_t.
/// NOTE(review): presumably the rank of the compression -- confirm whether
/// this is a maximum over nodes.
466 std::size_t
rank()
const override;
/// NOTE(review): presumably the memory footprint; the unit (bytes?) is not
/// visible here -- confirm.
467 std::size_t
memory()
const override;
/// NOTE(review): presumably the number of stored nonzero entries -- confirm.
468 std::size_t
nonzeros()
const override;
/// NOTE(review): presumably the number of levels in the hierarchy -- confirm.
469 std::size_t
levels()
const override;
471 (std::ostream& out=std::cout,
472 std::size_t roff=0, std::size_t coff=0)
const override;
/// Const member returning a DenseM_t by value.
/// NOTE(review): presumably expands the compressed representation into a
/// dense matrix -- confirm.
480 DenseM_t
dense()
const;
/// Overrides the base-class shift(), taking a scalar sigma.
/// NOTE(review): presumably adds sigma to the diagonal -- confirm.
482 void shift(scalar_t sigma)
override;
/// Const override writing to the output stream of. rlo and clo default to 0.
/// NOTE(review): rlo/clo are presumably row/column offsets of this block
/// within a larger drawing -- confirm.
484 void draw(std::ostream& of,
485 std::size_t rlo=0, std::size_t clo=0)
const override;
/// Const member taking a file name. It is an overload of the stream-based
/// write(std::ofstream&) declared later in this class.
/// NOTE(review): presumably serializes this matrix to the file fname --
/// confirm.
493 void write(
const std::string& fname)
const;
505 const opts_t& opts,
bool active);
// Data members: two HSSBasisID<scalar_t> objects (_U, _V) and three dense
// matrices (_D, _B01, _B10).
// NOTE(review): by naming, _U/_V are presumably the row/column bases, _D the
// diagonal block, and _B01/_B10 the off-diagonal coupling matrices between
// the two children -- confirm against the definitions.
509 HSSBasisID<scalar_t> _U, _V;
510 DenseM_t _D, _B01, _B10;
/// "Original" compression variant, dense-input overload: takes a dense
/// matrix A and options.
512 void compress_original(
const DenseM_t& A,
const opts_t& opts);
/// "Original" compression variant, callback overload: takes a mult_t callback
/// Amult, an elem_t element callback Aelem, and options.
/// NOTE(review): Amult is presumably used for random sampling and Aelem for
/// extracting entries -- confirm.
513 void compress_original
514 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
/// "Stable" compression variant, dense-input overload: takes a dense matrix A
/// and options.
515 void compress_stable(
const DenseM_t& A,
const opts_t& opts);
517 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
/// "Hard restart" compression variant, dense-input overload: takes a dense
/// matrix A and options.
/// NOTE(review): how the restart differs from the original/stable variants is
/// not visible here -- confirm in the definition.
518 void compress_hard_restart(
const DenseM_t& A,
const opts_t& opts);
/// "Hard restart" compression variant, callback overload: takes the Amult and
/// Aelem callbacks and options.
519 void compress_hard_restart
520 (
const mult_t& Amult,
const elem_t& Aelem,
const opts_t& opts);
/// Recursive step of the "original" compression variant (overrides a
/// base-class virtual). Arguments:
/// - Rr, Rc, Sr, Sc: four matrices passed by non-const reference;
/// - Aelem: the element callback;
/// - opts: options;
/// - w: per-node WorkCompress<scalar_t> state;
/// - dd, depth: integer parameters.
/// NOTE(review): Rr/Rc/Sr/Sc are presumably random and sample matrices for
/// rows/columns, and depth presumably the recursion depth -- confirm.
522 void compress_recursive_original
523 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
524 const elem_t& Aelem,
const opts_t& opts,
525 WorkCompress<scalar_t>& w,
int dd,
int depth)
override;
/// Recursive step of the "stable" compression variant (overrides a
/// base-class virtual). Same arguments as the original variant, plus an extra
/// integer d before dd.
526 void compress_recursive_stable
527 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
528 const elem_t& Aelem,
const opts_t& opts,
529 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth)
override;
/// Helper taking the four sample matrices Rr, Rc, Sr, Sc, the WorkCompress
/// state w, and integers d0, d, depth.
/// NOTE(review): presumably computes this node's local contribution to the
/// samples, with d0/d selecting a column range -- confirm.
530 void compute_local_samples
531 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
532 WorkCompress<scalar_t>& w,
int d0,
int d,
int depth);
/// Helper taking Sr, Sc, options, w, d and depth; returns a bool.
/// NOTE(review): the meaning of the bool (e.g. whether the bases were
/// successfully computed/converged) is not visible here -- confirm.
533 bool compute_U_V_bases
534 (DenseM_t& Sr, DenseM_t& Sc,
const opts_t& opts,
535 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth);
/// "Stable"-variant helper operating on Sr, with options, w and integers
/// d, dd, depth.
/// NOTE(review): presumably computes the U (row) basis -- confirm.
536 void compute_U_basis_stable
537 (DenseM_t& Sr,
const opts_t& opts,
538 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth);
/// "Stable"-variant counterpart of compute_U_basis_stable operating on Sc.
/// NOTE(review): presumably computes the V (column) basis -- confirm.
539 void compute_V_basis_stable
540 (DenseM_t& Sc,
const opts_t& opts,
541 WorkCompress<scalar_t>& w,
int d,
int dd,
int depth);
/// Helper taking Rr, Rc, w, and integers d0, d, depth.
/// NOTE(review): presumably reduces/projects the random samples Rr/Rc for use
/// at the parent level -- confirm.
542 void reduce_local_samples
543 (DenseM_t& Rr, DenseM_t& Rc, WorkCompress<scalar_t>& w,
544 int d0,
int d,
int depth);
/// Helper taking:
/// - options, and a scalar r_max_0 by non-const reference (so it may be
///   updated);
/// - a read-only matrix S and a modifiable matrix Q;
/// - integers d, dd, L, depth, and a bool untouched.
/// Returns a bool.
/// NOTE(review): presumably extends the orthonormal basis Q with new samples
/// S and reports convergence through the return value -- confirm.
545 bool update_orthogonal_basis
546 (
const opts_t& opts, scalar_t& r_max_0,
547 const DenseM_t& S, DenseM_t& Q,
int d,
int dd,
548 bool untouched,
int L,
int depth);
/// Helper taking only the WorkCompress state w.
/// NOTE(review): presumably sets the U basis to full rank (no compression) --
/// confirm.
549 void set_U_full_rank(WorkCompress<scalar_t>& w);
/// Counterpart of set_U_full_rank for the V basis (same signature).
550 void set_V_full_rank(WorkCompress<scalar_t>& w);
/// Per-level step of the "original" compression variant (overrides a
/// base-class virtual). Takes the four sample matrices, options, the
/// WorkCompress state w, and integers dd, lvl, depth.
/// NOTE(review): lvl presumably selects the tree level being processed --
/// confirm.
552 void compress_level_original
553 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
554 const opts_t& opts, WorkCompress<scalar_t>& w,
555 int dd,
int lvl,
int depth)
override;
/// Per-level step of the "stable" compression variant (overrides a
/// base-class virtual). Same arguments as the original variant, plus an extra
/// integer d before dd.
556 void compress_level_stable
557 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
558 const opts_t& opts, WorkCompress<scalar_t>& w,
559 int d,
int dd,
int lvl,
int depth)
override;
/// Overrides a base-class virtual. Parameters:
/// - I, J: output lists of index lists (row/column, presumably);
/// - off: a (row, column) offset pair;
/// - w: the WorkCompress state;
/// - self: an int counter passed by non-const reference, so the callee may
///   advance it;
/// - lvl: the tree level.
/// NOTE(review): presumably collects the index sets to extract for nodes at
/// level lvl, with self tracking this node's slot -- confirm.
560 void get_extraction_indices
561 (std::vector<std::vector<std::size_t>>& I,
562 std::vector<std::vector<std::size_t>>& J,
563 const std::pair<std::size_t,std::size_t>& off,
564 WorkCompress<scalar_t>& w,
int&
self,
int lvl)
override;
/// Overload of the above that additionally fills B, a vector of
/// non-owning DenseM_t pointers.
/// NOTE(review): B presumably holds the destination matrix for each (I, J)
/// pair -- confirm.
565 void get_extraction_indices
566 (std::vector<std::vector<std::size_t>>& I,
567 std::vector<std::vector<std::size_t>>& J, std::vector<DenseM_t*>& B,
568 const std::pair<std::size_t,std::size_t>& off,
569 WorkCompress<scalar_t>& w,
int&
self,
int lvl)
override;
571 (
const elem_t& Aelem,
const opts_t& opts,
572 WorkCompress<scalar_t>& w,
int lvl)
override;
576 void compress_recursive_ann
578 const elem_t& Aelem,
const opts_t& opts,
579 WorkCompressANN<scalar_t>& w,
int depth)
override;
580 void compute_local_samples_ann
582 WorkCompressANN<scalar_t>& w,
const elem_t& Aelem,
const opts_t& opts);
/// Variant of compute_U_V_bases that uses WorkCompressANN state. It takes a
/// single sample matrix S, options, w and depth, and returns a bool.
/// NOTE(review): "ANN" presumably refers to an approximate-nearest-neighbor
/// based compression path, and the bool's meaning is not visible here --
/// confirm both.
583 bool compute_U_V_bases_ann
584 (DenseM_t& S,
const opts_t& opts,
585 WorkCompressANN<scalar_t>& w,
int depth);
587 void factor_recursive
589 bool isroot,
bool partial,
int depth)
const override;
592 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
593 int depth, std::atomic<long long int>& flops)
const override;
595 (
const DenseM_t& b, scalar_t beta, DenseM_t& c,
596 WorkApply<scalar_t>& w,
bool isroot,
int depth,
597 std::atomic<long long int>& flops)
const override;
599 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
600 int depth, std::atomic<long long int>& flops)
const override;
602 (
const DenseM_t& b, scalar_t beta, DenseM_t& c,
603 WorkApply<scalar_t>& w,
bool isroot,
int depth,
604 std::atomic<long long int>& flops)
const override;
608 WorkSolve<scalar_t>& w,
609 bool partial,
bool isroot,
int depth)
const override;
612 WorkSolve<scalar_t>& w,
613 bool isroot,
int depth)
const override;
616 (WorkExtract<scalar_t>& w,
bool odiag,
int depth)
const override;
618 (DenseM_t& B, WorkExtract<scalar_t>& w,
int depth)
const override;
620 (std::vector<Triplet<scalar_t>>& triplets,
621 WorkExtract<scalar_t>& w,
int depth)
const override;
/// Const, non-virtual helper taking the WorkExtract<scalar_t> state w and
/// a depth.
/// NOTE(review): presumably the internal backward pass of element extraction
/// -- confirm in the definition.
622 void extract_bwd_internal(WorkExtract<scalar_t>& w,
int depth)
const;
625 (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
626 DenseM_t& Vop,
const std::pair<std::size_t, std::size_t>& offset,
627 int depth, std::atomic<long long int>& flops)
const override;
629 (
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
630 const std::pair<std::size_t, std::size_t>& offset,
631 int depth, std::atomic<long long int>& flops)
const override;
634 (DenseM_t& A, WorkDense<scalar_t>& w,
635 bool isroot,
int depth)
const override;
640 template<
typename T>
friend void apply_HSS
647 template<
typename T>
friend void draw
/// Overrides the base-class read(), consuming data from the input file
/// stream is.
/// NOTE(review): presumably deserializes the format produced by write() --
/// confirm.
650 void read(std::ifstream& is)
override;
/// Const override of the base-class write(), emitting to the output file
/// stream os.
651 void write(std::ofstream& os)
const override;
/// Free function template taking a read-only HSSMatrix<scalar_t> H and a
/// name string.
/// NOTE(review): presumably writes a visualization of H's structure to an
/// output named after `name` -- confirm against the definition.
663 template<
typename scalar_t>
664 void draw(
const HSSMatrix<scalar_t>& H,
const std::string& name);
677 template<
typename scalar_t>
void apply_HSS
684 #endif // HSS_MATRIX_HPP