// Alias for the nested value_type of RealType<scalar_t>.
// NOTE(review): presumably the real-valued counterpart of scalar_t; RealType
// is not defined in this view -- confirm.
81 using real_t =
typename RealType<scalar_t>::value_type;
// Callable type returning void and taking two index vectors (I, J) by const
// reference plus a mutable DenseM_t& B.
// NOTE(review): presumably B receives the entries selected by rows I and
// columns J -- confirm against the code that invokes elem_t.
84 using elem_t =
typename std::function
85 <void(
const std::vector<std::size_t>& I,
86 const std::vector<std::size_t>& J,
DenseM_t& B)>;
87 using mult_t =
typename std::function
// Const member overriding a base-class virtual; returns a
// std::unique_ptr<HSSMatrixBase<scalar_t>>, so the caller owns the result.
// NOTE(review): presumably a deep copy of this matrix -- the definition is
// not visible here.
186 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
274 const std::function<
void(
const std::vector<std::size_t>& I,
275 const std::vector<std::size_t>& J,
308 <
void(
const std::vector<std::size_t>& I,
309 const std::vector<std::size_t>& J,
365 bool partial)
const override;
// Const accessor taking two std::size_t indices and returning a scalar_t by
// value.
// NOTE(review): presumably i is the row and j the column of the requested
// entry; bounds behaviour is not visible here -- confirm.
427 scalar_t
get(std::size_t i, std::size_t j)
const;
440 const std::vector<std::size_t>& J)
const;
456 const std::vector<std::size_t>& J,
459#ifndef DOXYGEN_SHOULD_SKIP_THIS
463 void Schur_product_direct(
const DenseM_t& Theta,
466 const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
469 void Schur_product_indirect(
const DenseM_t& DUB01,
// Non-const override of a base-class virtual; takes no arguments and returns
// nothing.
// NOTE(review): which block is removed, and what state remains valid
// afterwards, is defined in the implementation -- not visible here.
474 void delete_trailing_block()
override;
// Const override of a base-class virtual returning a std::size_t.
// NOTE(review): presumably the (maximum) off-diagonal rank of the
// representation -- confirm against the base-class documentation.
477 std::size_t
rank()
const override;
484 std::size_t coff=0)
const override;
// Non-const override of a base-class virtual taking a scalar_t by value.
// NOTE(review): presumably applies a diagonal shift by sigma to the stored
// matrix -- confirm against the implementation.
494 void shift(scalar_t sigma)
override;
497 std::size_t rlo=0, std::size_t clo=0)
const override;
// Const, non-virtual overload of write() taking a file name by const
// reference (a separate std::ofstream& overload is declared further down).
// NOTE(review): file format and error handling are not visible here.
505 void write(
const std::string& fname)
const;
519 const opts_t& opts,
bool active);
521 const opts_t& opts,
bool active);
// Data members: two HSSBasisID<scalar_t> objects and three DenseM_t objects.
// NOTE(review): presumably U_/V_ are the row/column bases, D_ the diagonal
// block and B01_/B10_ the off-diagonal coupling blocks -- inferred from the
// names only; confirm.
524 HSSBasisID<scalar_t> U_, V_;
525 DenseM_t D_, B01_, B10_;
527 void compress_original(
const DenseM_t& A,
529 void compress_original(
const mult_t& Amult,
532 void compress_stable(
const DenseM_t& A,
534 void compress_stable(
const mult_t& Amult,
537 void compress_hard_restart(
const DenseM_t& A,
539 void compress_hard_restart(
const mult_t& Amult,
// Override of a base-class virtual. Visible parameters: four mutable
// DenseM_t references (Rr, Rc, Sr, Sc), a WorkCompress<scalar_t>& workspace
// and two ints (dd, depth).
// NOTE(review): the embedded listing numbers jump from 544 to 547, so
// parameters between Sc and w may be missing from this view -- check the
// full header before relying on this signature.
543 void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
544 DenseM_t& Sr, DenseM_t& Sc,
547 WorkCompress<scalar_t>& w,
548 int dd,
int depth)
override;
// Override of a base-class virtual. Visible parameters: four mutable
// DenseM_t references (Rr, Rc, Sr, Sc), a WorkCompress<scalar_t>& workspace
// and three ints (d, dd, depth).
// NOTE(review): the embedded listing numbers jump from 550 to 553, so
// parameters between Sc and w may be missing from this view -- check the
// full header before relying on this signature.
549 void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
550 DenseM_t& Sr, DenseM_t& Sc,
553 WorkCompress<scalar_t>& w,
554 int d,
int dd,
int depth)
override;
// Non-virtual helper (no override specifier). Takes four mutable DenseM_t
// references, a WorkCompress<scalar_t>& workspace, three ints (d0, d, depth)
// and an optional SJLTMatrix<scalar_t, int>* that defaults to nullptr.
// NOTE(review): presumably S selects a sparse sketching path when non-null;
// meaning of d0 versus d is not visible here -- confirm.
555 void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
556 DenseM_t& Sr, DenseM_t& Sc,
557 WorkCompress<scalar_t>& w,
558 int d0,
int d,
int depth,
559 SJLTMatrix<scalar_t, int>* S=
nullptr);
// Non-virtual helper returning bool. Takes two mutable DenseM_t references
// (Sr, Sc), const options, a WorkCompress<scalar_t>& workspace and two ints.
// NOTE(review): meaning of the bool result (e.g. "accuracy reached") is not
// visible here -- confirm against the definition.
560 bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc,
const opts_t& opts,
561 WorkCompress<scalar_t>& w,
int d,
int depth);
// Non-virtual helper. Takes one mutable DenseM_t& (Sr), const options, a
// WorkCompress<scalar_t>& workspace and three ints (d, dd, depth).
// Signature mirrors compute_V_basis_stable, which takes Sc instead of Sr.
562 void compute_U_basis_stable(DenseM_t& Sr,
const opts_t& opts,
563 WorkCompress<scalar_t>& w,
564 int d,
int dd,
int depth);
// Non-virtual helper. Takes one mutable DenseM_t& (Sc), const options, a
// WorkCompress<scalar_t>& workspace and three ints (d, dd, depth).
// Signature mirrors compute_U_basis_stable, which takes Sr instead of Sc.
565 void compute_V_basis_stable(DenseM_t& Sc,
const opts_t& opts,
566 WorkCompress<scalar_t>& w,
567 int d,
int dd,
int depth);
// Non-virtual helper. Takes two mutable DenseM_t references (Rr, Rc), a
// WorkCompress<scalar_t>& workspace and three ints (d0, d, depth).
// NOTE(review): how Rr/Rc are reduced is defined in the implementation,
// which is not visible here.
568 void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
569 WorkCompress<scalar_t>& w,
570 int d0,
int d,
int depth);
571 bool update_orthogonal_basis(
const opts_t& opts, scalar_t& r_max_0,
572 const DenseM_t& S, DenseM_t& Q,
573 int d,
int dd,
bool untouched,
// Non-virtual helper taking the WorkCompress<scalar_t>& workspace.
// NOTE(review): presumably marks the U basis as full rank (no compression)
// -- inferred from the name; confirm.
575 void set_U_full_rank(WorkCompress<scalar_t>& w);
// Non-virtual helper taking the WorkCompress<scalar_t>& workspace.
// NOTE(review): presumably marks the V basis as full rank (no compression)
// -- inferred from the name; confirm.
576 void set_V_full_rank(WorkCompress<scalar_t>& w);
// Override of a base-class virtual. Visible parameters: four mutable
// DenseM_t references, a WorkCompress<scalar_t>& workspace and three ints
// (dd, lvl, depth).
// NOTE(review): the embedded listing numbers jump from 579 to 581, so a
// parameter line between Sc and w may be missing from this view -- check
// the full header.
578 void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
579 DenseM_t& Sr, DenseM_t& Sc,
581 WorkCompress<scalar_t>& w,
582 int dd,
int lvl,
int depth)
override;
// Override of a base-class virtual. Visible parameters: four mutable
// DenseM_t references, a WorkCompress<scalar_t>& workspace and four ints
// (d, dd, lvl, depth).
// NOTE(review): the embedded listing numbers jump from 584 to 586, so a
// parameter line between Sc and w may be missing from this view -- check
// the full header.
583 void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
584 DenseM_t& Sr, DenseM_t& Sc,
586 WorkCompress<scalar_t>& w,
587 int d,
int dd,
int lvl,
int depth)
override;
// Override of a base-class virtual. I and J are mutable vectors of index
// vectors, off is a const pair of std::size_t, w is the compression
// workspace, self is a mutable int reference and lvl is passed by value.
// NOTE(review): presumably appends index sets for the nodes at level lvl and
// advances self as a node counter -- confirm against the definition.
589 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
590 std::vector<std::vector<std::size_t>>& J,
591 const std::pair<std::size_t,std::size_t>& off,
592 WorkCompress<scalar_t>& w,
593 int& self,
int lvl)
override;
// Overload of get_extraction_indices that additionally takes a mutable
// vector of DenseM_t pointers (B); otherwise the parameters match the
// overload without B. Also overrides a base-class virtual.
// NOTE(review): ownership of the DenseM_t objects pointed to by B is not
// visible here -- confirm before storing or freeing them.
594 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
595 std::vector<std::vector<std::size_t>>& J,
596 std::vector<DenseM_t*>& B,
597 const std::pair<std::size_t,std::size_t>& off,
598 WorkCompress<scalar_t>& w,
599 int& self,
int lvl)
override;
// Override of a base-class virtual. Takes the element callback (elem_t) and
// options by const reference, the compression workspace and a level index.
// NOTE(review): presumably fills the D and B blocks by calling Aelem --
// inferred from the name; confirm.
600 void extract_D_B(
const elem_t& Aelem,
const opts_t& opts,
601 WorkCompress<scalar_t>& w,
int lvl)
override;
606 const elem_t& Aelem,
const opts_t& opts,
607 WorkCompressANN<scalar_t>& w,
611 WorkCompressANN<scalar_t>& w,
612 const elem_t& Aelem,
const opts_t& opts);
// Non-virtual helper returning bool. Takes a single mutable DenseM_t& S,
// const options, a WorkCompressANN<scalar_t>& workspace and a depth.
// NOTE(review): meaning of the bool result is not visible here -- confirm.
613 bool compute_U_V_bases_ann(DenseM_t& S,
const opts_t& opts,
614 WorkCompressANN<scalar_t>& w,
int depth);
616 void factor_recursive(WorkFactor<scalar_t>& w,
617 bool isroot,
bool partial,
// Const override of a base-class virtual. b is read-only, w is a mutable
// WorkApply<scalar_t> workspace, and flops is a shared atomic counter passed
// by reference.
// NOTE(review): presumably the forward (upward) sweep of a matrix-vector
// product that adds its operation count to flops -- confirm.
620 void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
621 bool isroot,
int depth,
622 std::atomic<long long int>& flops)
const override;
// Const override of a base-class virtual. b is read-only, c is a mutable
// output/accumulator, beta is a scalar passed by value, w is the apply
// workspace and flops is an atomic counter passed by reference.
// NOTE(review): presumably the backward sweep that forms c from beta*c plus
// the product -- the exact update rule is not visible here; confirm.
623 void apply_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
624 WorkApply<scalar_t>& w,
bool isroot,
int depth,
625 std::atomic<long long int>& flops)
const override;
// Const override of a base-class virtual with the same parameter list as
// apply_fwd.
// NOTE(review): presumably the transposed (or conjugate-transposed)
// counterpart of apply_fwd -- which one is not visible here; confirm.
626 void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
627 int depth, std::atomic<long long int>& flops)
const override;
// Const override of a base-class virtual with the same parameter list as
// apply_bwd.
// NOTE(review): presumably the transposed (or conjugate-transposed)
// counterpart of apply_bwd -- which one is not visible here; confirm.
628 void applyT_bwd(
const DenseM_t& b, scalar_t beta, DenseM_t& c,
629 WorkApply<scalar_t>& w,
bool isroot,
int depth,
630 std::atomic<long long int>& flops)
const override;
// Const override of a base-class virtual. b is read-only; results are
// carried in the mutable WorkSolve<scalar_t> workspace. partial and isroot
// are flags, depth is passed by value.
// NOTE(review): presumably the forward phase of a solve with a previously
// computed factorization -- confirm against the definition.
632 void solve_fwd(
const DenseM_t& b, WorkSolve<scalar_t>& w,
633 bool partial,
bool isroot,
int depth)
const override;
// Const override of a base-class virtual. x is a mutable DenseM_t and w the
// mutable solve workspace.
// NOTE(review): presumably the backward phase that writes the solution into
// x -- confirm against the definition.
634 void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
635 bool isroot,
int depth)
const override;
// Const override of a base-class virtual operating on a mutable
// WorkExtract<scalar_t> workspace, with a bool flag and a depth.
// NOTE(review): meaning of odiag (presumably "off-diagonal") is inferred
// from the name only -- confirm.
637 void extract_fwd(WorkExtract<scalar_t>& w,
638 bool odiag,
int depth)
const override;
// Const override of a base-class virtual; dense overload that takes a
// mutable DenseM_t& B alongside the extraction workspace.
// NOTE(review): presumably writes the extracted entries into B -- confirm.
639 void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
640 int depth)
const override;
// Const override of a base-class virtual; overload that takes a mutable
// vector of Triplet<scalar_t> instead of a dense matrix.
// NOTE(review): presumably appends extracted entries as triplets -- whether
// the vector is cleared first is not visible here; confirm.
641 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
642 WorkExtract<scalar_t>& w,
int depth)
const override;
// Const, non-virtual helper (no override specifier) taking only the
// extraction workspace and a depth.
// NOTE(review): presumably shared by both extract_bwd overloads -- inferred
// from the name; confirm.
643 void extract_bwd_internal(WorkExtract<scalar_t>& w,
int depth)
const;
645 void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
647 const std::pair<std::size_t, std::size_t>& offset,
648 int depth, std::atomic<long long int>& flops)
650 void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
651 const std::pair<std::size_t, std::size_t>& offset,
652 int depth, std::atomic<long long int>& flops)
// Const override of a base-class virtual. A is a mutable DenseM_t and w a
// mutable WorkDense<scalar_t> workspace.
// NOTE(review): presumably expands the compressed representation into the
// dense matrix A -- confirm against the definition.
655 void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
656 bool isroot,
int depth)
const override;
661 template<
typename T>
friend
668 template<
typename T>
friend
// Non-const override of a base-class virtual that takes an open
// std::ifstream by reference.
// NOTE(review): presumably restores this object from the format produced by
// write(std::ofstream&); stream error handling is not visible here.
671 void read(std::ifstream& is)
override;
// Const override of a base-class virtual that takes an open std::ofstream
// by reference (distinct from the write(const std::string&) overload).
// NOTE(review): serialization format is defined in the implementation, which
// is not visible here.
672 void write(std::ofstream& os)
const override;