HSSMatrix.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  *
28  */
37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
39 
40 #include <cassert>
41 #include <functional>
42 #include <string>
43 
44 #include "HSSBasisID.hpp"
45 #include "HSSOptions.hpp"
46 #include "HSSExtra.hpp"
47 #include "HSSMatrixBase.hpp"
48 #include "kernel/Kernel.hpp"
49 #include "HSSMatrix.sketch.hpp"
50 
51 namespace strumpack {
52  namespace HSS {
53 
54 #ifndef DOXYGEN_SHOULD_SKIP_THIS
55  // forward declaration
56  template<typename scalar_t> class HSSMatrixMPI;
57 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
58 
59 
79  template<typename scalar_t> class HSSMatrix
80  : public HSSMatrixBase<scalar_t> {
81  using real_t = typename RealType<scalar_t>::value_type;
84  using elem_t = typename std::function
85  <void(const std::vector<std::size_t>& I,
86  const std::vector<std::size_t>& J, DenseM_t& B)>;
87  using mult_t = typename std::function
88  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>;
90 
91  public:
96 
114  HSSMatrix(const DenseM_t& A, const opts_t& opts);
115 
131  HSSMatrix(std::size_t m, std::size_t n, const opts_t& opts);
132 
143  HSSMatrix(const structured::ClusterTree& t, const opts_t& opts);
144 
155 
162 
169 
174  HSSMatrix(HSSMatrix<scalar_t>&& other) = default;
175 
181 
186  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
187 
194  const HSSMatrix<scalar_t>* child(int c) const {
195  return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
196  }
197 
205  return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
206  }
207 
224  void compress(const DenseM_t& A, const opts_t& opts);
225 
270  void compress(const std::function<void(DenseM_t& Rr,
271  DenseM_t& Rc,
272  DenseM_t& Sr,
273  DenseM_t& Sc)>& Amult,
274  const std::function<void(const std::vector<std::size_t>& I,
275  const std::vector<std::size_t>& J,
276  DenseM_t& B)>& Aelem,
277  const opts_t& opts);
278 
279 
307  const std::function
308  <void(const std::vector<std::size_t>& I,
309  const std::vector<std::size_t>& J,
310  DenseM_t& B)>& Aelem,
311  const opts_t& opts);
312 
317  void reset() override;
318 
322  void factor() override;
323 
335 
347  void solve(DenseM_t& b) const override;
348 
364  void forward_solve(WorkSolve<scalar_t>& w, const DenseM_t& b,
365  bool partial) const override;
366 
380  void backward_solve(WorkSolve<scalar_t>& w, DenseM_t& x) const override;
381 
390  DenseM_t apply(const DenseM_t& b) const;
391 
404  void mult(Trans op, const DenseM_t& x, DenseM_t& y) const override;
405 
415  DenseM_t applyC(const DenseM_t& b) const;
416 
427  scalar_t get(std::size_t i, std::size_t j) const;
428 
439  DenseM_t extract(const std::vector<std::size_t>& I,
440  const std::vector<std::size_t>& J) const;
441 
455  void extract_add(const std::vector<std::size_t>& I,
456  const std::vector<std::size_t>& J,
457  DenseM_t& B) const;
458 
459 #ifndef DOXYGEN_SHOULD_SKIP_THIS
460  void Schur_update(DenseM_t& Theta,
461  DenseM_t& DUB01,
462  DenseM_t& Phi) const;
463  void Schur_product_direct(const DenseM_t& Theta,
464  const DenseM_t& DUB01,
465  const DenseM_t& Phi,
466  const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
467  const DenseM_t& R,
468  DenseM_t& Sr, DenseM_t& Sc) const;
469  void Schur_product_indirect(const DenseM_t& DUB01,
470  const DenseM_t& R1,
471  const DenseM_t& R2, const DenseM_t& Sr2,
472  const DenseM_t& Sc2,
473  DenseM_t& Sr, DenseM_t& Sc) const;
474  void delete_trailing_block() override;
475 #endif // DOXYGEN_SHOULD_SKIP_THIS
476 
477  std::size_t rank() const override;
478  std::size_t memory() const override;
479  std::size_t nonzeros() const override;
480  std::size_t levels() const override;
481 
482  void print_info(std::ostream& out=std::cout,
483  std::size_t roff=0,
484  std::size_t coff=0) const override;
485 
492  DenseM_t dense() const;
493 
494  void shift(scalar_t sigma) override;
495 
496  void draw(std::ostream& of,
497  std::size_t rlo=0, std::size_t clo=0) const override;
498 
505  void write(const std::string& fname) const;
506 
513  static HSSMatrix<scalar_t> read(const std::string& fname);
514 
515  const HSSFactors<scalar_t>& ULV() { return this->ULV_; }
516 
517  protected:
518  HSSMatrix(std::size_t m, std::size_t n,
519  const opts_t& opts, bool active);
521  const opts_t& opts, bool active);
522  HSSMatrix(std::ifstream& is);
523 
524  HSSBasisID<scalar_t> U_, V_;
525  DenseM_t D_, B01_, B10_;
526 
527  void compress_original(const DenseM_t& A,
528  const opts_t& opts);
529  void compress_original(const mult_t& Amult,
530  const elem_t& Aelem,
531  const opts_t& opts);
532  void compress_stable(const DenseM_t& A,
533  const opts_t& opts);
534  void compress_stable(const mult_t& Amult,
535  const elem_t& Aelem,
536  const opts_t& opts);
537  void compress_hard_restart(const DenseM_t& A,
538  const opts_t& opts);
539  void compress_hard_restart(const mult_t& Amult,
540  const elem_t& Aelem,
541  const opts_t& opts);
542 
543  void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
544  DenseM_t& Sr, DenseM_t& Sc,
545  const elem_t& Aelem,
546  const opts_t& opts,
547  WorkCompress<scalar_t>& w,
548  int dd, int depth) override;
549  void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
550  DenseM_t& Sr, DenseM_t& Sc,
551  const elem_t& Aelem,
552  const opts_t& opts,
553  WorkCompress<scalar_t>& w,
554  int d, int dd, int depth) override;
555  // SJLT_Matrix<scalar_t,int>* S=nullptr
556  void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
557  DenseM_t& Sr, DenseM_t& Sc,
558  WorkCompress<scalar_t>& w,
559  int d0, int d, int depth,
560  SJLTMatrix<scalar_t, int>* S=nullptr);
561  bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc, const opts_t& opts,
562  WorkCompress<scalar_t>& w, int d, int depth);
563  void compute_U_basis_stable(DenseM_t& Sr, const opts_t& opts,
564  WorkCompress<scalar_t>& w,
565  int d, int dd, int depth);
566  void compute_V_basis_stable(DenseM_t& Sc, const opts_t& opts,
567  WorkCompress<scalar_t>& w,
568  int d, int dd, int depth);
569  void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
570  WorkCompress<scalar_t>& w,
571  int d0, int d, int depth);
572  bool update_orthogonal_basis(const opts_t& opts, scalar_t& r_max_0,
573  const DenseM_t& S, DenseM_t& Q,
574  int d, int dd, bool untouched,
575  int L, int depth);
576  void set_U_full_rank(WorkCompress<scalar_t>& w);
577  void set_V_full_rank(WorkCompress<scalar_t>& w);
578 
579  void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
580  DenseM_t& Sr, DenseM_t& Sc,
581  const opts_t& opts,
582  WorkCompress<scalar_t>& w,
583  int dd, int lvl, int depth) override;
584  void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
585  DenseM_t& Sr, DenseM_t& Sc,
586  const opts_t& opts,
587  WorkCompress<scalar_t>& w,
588  int d, int dd, int lvl, int depth) override;
589  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
590  std::vector<std::vector<std::size_t>>& J,
591  const std::pair<std::size_t,std::size_t>& off,
592  WorkCompress<scalar_t>& w,
593  int& self, int lvl) override;
594  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
595  std::vector<std::vector<std::size_t>>& J,
596  std::vector<DenseM_t*>& B,
597  const std::pair<std::size_t,std::size_t>& off,
598  WorkCompress<scalar_t>& w,
599  int& self, int lvl) override;
600  void extract_D_B(const elem_t& Aelem, const opts_t& opts,
601  WorkCompress<scalar_t>& w, int lvl) override;
602 
603  void compress(const kernel::Kernel<real_t>& K, const opts_t& opts);
604  void compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
605  DenseMatrix<real_t>& scores,
606  const elem_t& Aelem, const opts_t& opts,
607  WorkCompressANN<scalar_t>& w,
608  int depth) override;
609  void compute_local_samples_ann(DenseMatrix<std::uint32_t>& ann,
610  DenseMatrix<real_t>& scores,
611  WorkCompressANN<scalar_t>& w,
612  const elem_t& Aelem, const opts_t& opts);
613  bool compute_U_V_bases_ann(DenseM_t& S, const opts_t& opts,
614  WorkCompressANN<scalar_t>& w, int depth);
615 
616  void factor_recursive(WorkFactor<scalar_t>& w,
617  bool isroot, bool partial,
618  int depth) override;
619 
620  void apply_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
621  bool isroot, int depth,
622  std::atomic<long long int>& flops) const override;
623  void apply_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
624  WorkApply<scalar_t>& w, bool isroot, int depth,
625  std::atomic<long long int>& flops) const override;
626  void applyT_fwd(const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
627  int depth, std::atomic<long long int>& flops) const override;
628  void applyT_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
629  WorkApply<scalar_t>& w, bool isroot, int depth,
630  std::atomic<long long int>& flops) const override;
631 
632  void solve_fwd(const DenseM_t& b, WorkSolve<scalar_t>& w,
633  bool partial, bool isroot, int depth) const override;
634  void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
635  bool isroot, int depth) const override;
636 
637  void extract_fwd(WorkExtract<scalar_t>& w,
638  bool odiag, int depth) const override;
639  void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
640  int depth) const override;
641  void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
642  WorkExtract<scalar_t>& w, int depth) const override;
643  void extract_bwd_internal(WorkExtract<scalar_t>& w, int depth) const;
644 
645  void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
646  DenseM_t& Vop,
647  const std::pair<std::size_t, std::size_t>& offset,
648  int depth, std::atomic<long long int>& flops)
649  const override;
650  void apply_UtVt_big(const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
651  const std::pair<std::size_t, std::size_t>& offset,
652  int depth, std::atomic<long long int>& flops)
653  const override;
654 
655  void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
656  bool isroot, int depth) const override;
657 
661  template<typename T> friend
662  void apply_HSS(Trans ta, const HSSMatrix<T>& a, const DenseMatrix<T>& b,
663  T beta, DenseMatrix<T>& c);
664 
668  template<typename T> friend
669  void draw(const HSSMatrix<T>& H, const std::string& name);
670 
671  void read(std::ifstream& is) override;
672  void write(std::ofstream& os) const override;
673 
674  friend class HSSMatrixMPI<scalar_t>;
675 
677  };
678 
686  template<typename scalar_t>
687  void draw(const HSSMatrix<scalar_t>& H, const std::string& name);
688 
700  template<typename scalar_t> void
702  const DenseMatrix<scalar_t>& B,
703  scalar_t beta, DenseMatrix<scalar_t>& C);
704 
705  } // end namespace HSS
706 } // end namespace strumpack
707 
708 #endif // HSS_MATRIX_HPP
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representation.
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kernel ridge regression.
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:80
void factor() override
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:194
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
scalar_t get(std::size_t i, std::size_t j) const
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
void write(const std::string &fname) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:204
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
DenseM_t dense() const
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree, that represents the partitioning of the rows or columns of a hierarchical matrix.
Definition: ClusterTree.hpp:67
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51