HSSMatrix.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  *
28  */
37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
39 
40 #include <cassert>
41 #include <functional>
42 #include <string>
43 
44 #include "HSSBasisID.hpp"
45 #include "HSSOptions.hpp"
46 #include "HSSExtra.hpp"
47 #include "HSSMatrixBase.hpp"
48 #include "kernel/Kernel.hpp"
49 
50 namespace strumpack {
51  namespace HSS {
52 
53 #ifndef DOXYGEN_SHOULD_SKIP_THIS
54  // forward declaration
55  template<typename scalar_t> class HSSMatrixMPI;
56 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
57 
58 
78  template<typename scalar_t> class HSSMatrix
79  : public HSSMatrixBase<scalar_t> {
80  using real_t = typename RealType<scalar_t>::value_type;
83  using elem_t = typename std::function
84  <void(const std::vector<std::size_t>& I,
85  const std::vector<std::size_t>& J, DenseM_t& B)>;
86  using mult_t = typename std::function
87  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>;
89 
90  public:
95 
113  HSSMatrix(const DenseM_t& A, const opts_t& opts);
114 
130  HSSMatrix(std::size_t m, std::size_t n, const opts_t& opts);
131 
142  HSSMatrix(const structured::ClusterTree& t, const opts_t& opts);
143 
154 
161 
168 
173  HSSMatrix(HSSMatrix<scalar_t>&& other) = default;
174 
180 
185  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
186 
193  const HSSMatrix<scalar_t>* child(int c) const {
194  return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
195  }
196 
204  return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
205  }
206 
223  void compress(const DenseM_t& A, const opts_t& opts);
224 
269  void compress(const std::function<void(DenseM_t& Rr,
270  DenseM_t& Rc,
271  DenseM_t& Sr,
272  DenseM_t& Sc)>& Amult,
273  const std::function<void(const std::vector<std::size_t>& I,
274  const std::vector<std::size_t>& J,
275  DenseM_t& B)>& Aelem,
276  const opts_t& opts);
277 
278 
306  const std::function
307  <void(const std::vector<std::size_t>& I,
308  const std::vector<std::size_t>& J,
309  DenseM_t& B)>& Aelem,
310  const opts_t& opts);
311 
316  void reset() override;
317 
321  void factor() override;
322 
334 
346  void solve(DenseM_t& b) const override;
347 
363  void forward_solve(WorkSolve<scalar_t>& w, const DenseM_t& b,
364  bool partial) const override;
365 
379  void backward_solve(WorkSolve<scalar_t>& w, DenseM_t& x) const override;
380 
389  DenseM_t apply(const DenseM_t& b) const;
390 
403  void mult(Trans op, const DenseM_t& x, DenseM_t& y) const override;
404 
414  DenseM_t applyC(const DenseM_t& b) const;
415 
426  scalar_t get(std::size_t i, std::size_t j) const;
427 
438  DenseM_t extract(const std::vector<std::size_t>& I,
439  const std::vector<std::size_t>& J) const;
440 
454  void extract_add(const std::vector<std::size_t>& I,
455  const std::vector<std::size_t>& J,
456  DenseM_t& B) const;
457 
458 #ifndef DOXYGEN_SHOULD_SKIP_THIS
459  void Schur_update(DenseM_t& Theta,
460  DenseM_t& DUB01,
461  DenseM_t& Phi) const;
462  void Schur_product_direct(const DenseM_t& Theta,
463  const DenseM_t& DUB01,
464  const DenseM_t& Phi,
465  const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
466  const DenseM_t& R,
467  DenseM_t& Sr, DenseM_t& Sc) const;
468  void Schur_product_indirect(const DenseM_t& DUB01,
469  const DenseM_t& R1,
470  const DenseM_t& R2, const DenseM_t& Sr2,
471  const DenseM_t& Sc2,
472  DenseM_t& Sr, DenseM_t& Sc) const;
473  void delete_trailing_block() override;
474 #endif // DOXYGEN_SHOULD_SKIP_THIS
475 
476  std::size_t rank() const override;
477  std::size_t memory() const override;
478  std::size_t nonzeros() const override;
479  std::size_t levels() const override;
480 
481  void print_info(std::ostream& out=std::cout,
482  std::size_t roff=0,
483  std::size_t coff=0) const override;
484 
491  DenseM_t dense() const;
492 
493  void shift(scalar_t sigma) override;
494 
495  void draw(std::ostream& of,
496  std::size_t rlo=0, std::size_t clo=0) const override;
497 
504  void write(const std::string& fname) const;
505 
512  static HSSMatrix<scalar_t> read(const std::string& fname);
513 
514  const HSSFactors<scalar_t>& ULV() { return this->ULV_; }
515 
516  protected:
517  HSSMatrix(std::size_t m, std::size_t n,
518  const opts_t& opts, bool active);
520  const opts_t& opts, bool active);
521  HSSMatrix(std::ifstream& is);
522 
523  HSSBasisID<scalar_t> U_, V_;
524  DenseM_t D_, B01_, B10_;
525 
526  void compress_original(const DenseM_t& A,
527  const opts_t& opts);
528  void compress_original(const mult_t& Amult,
529  const elem_t& Aelem,
530  const opts_t& opts);
531  void compress_stable(const DenseM_t& A,
532  const opts_t& opts);
533  void compress_stable(const mult_t& Amult,
534  const elem_t& Aelem,
535  const opts_t& opts);
536  void compress_hard_restart(const DenseM_t& A,
537  const opts_t& opts);
538  void compress_hard_restart(const mult_t& Amult,
539  const elem_t& Aelem,
540  const opts_t& opts);
541 
542  void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
543  DenseM_t& Sr, DenseM_t& Sc,
544  const elem_t& Aelem,
545  const opts_t& opts,
546  WorkCompress<scalar_t>& w,
547  int dd, int depth) override;
548  void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
549  DenseM_t& Sr, DenseM_t& Sc,
550  const elem_t& Aelem,
551  const opts_t& opts,
552  WorkCompress<scalar_t>& w,
553  int d, int dd, int depth) override;
554  void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
555  DenseM_t& Sr, DenseM_t& Sc,
556  WorkCompress<scalar_t>& w,
557  int d0, int d, int depth);
558  bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc, const opts_t& opts,
559  WorkCompress<scalar_t>& w, int d, int depth);
560  void compute_U_basis_stable(DenseM_t& Sr, const opts_t& opts,
561  WorkCompress<scalar_t>& w,
562  int d, int dd, int depth);
563  void compute_V_basis_stable(DenseM_t& Sc, const opts_t& opts,
564  WorkCompress<scalar_t>& w,
565  int d, int dd, int depth);
566  void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
567  WorkCompress<scalar_t>& w,
568  int d0, int d, int depth);
569  bool update_orthogonal_basis(const opts_t& opts, scalar_t& r_max_0,
570  const DenseM_t& S, DenseM_t& Q,
571  int d, int dd, bool untouched,
572  int L, int depth);
573  void set_U_full_rank(WorkCompress<scalar_t>& w);
574  void set_V_full_rank(WorkCompress<scalar_t>& w);
575 
576  void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
577  DenseM_t& Sr, DenseM_t& Sc,
578  const opts_t& opts,
579  WorkCompress<scalar_t>& w,
580  int dd, int lvl, int depth) override;
581  void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
582  DenseM_t& Sr, DenseM_t& Sc,
583  const opts_t& opts,
584  WorkCompress<scalar_t>& w,
585  int d, int dd, int lvl, int depth) override;
586  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
587  std::vector<std::vector<std::size_t>>& J,
588  const std::pair<std::size_t,std::size_t>& off,
589  WorkCompress<scalar_t>& w,
590  int& self, int lvl) override;
591  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
592  std::vector<std::vector<std::size_t>>& J,
593  std::vector<DenseM_t*>& B,
594  const std::pair<std::size_t,std::size_t>& off,
595  WorkCompress<scalar_t>& w,
596  int& self, int lvl) override;
597  void extract_D_B(const elem_t& Aelem, const opts_t& opts,
598  WorkCompress<scalar_t>& w, int lvl) override;
599 
600  void compress(const kernel::Kernel<real_t>& K, const opts_t& opts);
601  void compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
602  DenseMatrix<real_t>& scores,
603  const elem_t& Aelem, const opts_t& opts,
604  WorkCompressANN<scalar_t>& w,
605  int depth) override;
606  void compute_local_samples_ann(DenseMatrix<std::uint32_t>& ann,
607  DenseMatrix<real_t>& scores,
608  WorkCompressANN<scalar_t>& w,
609  const elem_t& Aelem, const opts_t& opts);
610  bool compute_U_V_bases_ann(DenseM_t& S, const opts_t& opts,
611  WorkCompressANN<scalar_t>& w, int depth);
612 
613  void factor_recursive(WorkFactor<scalar_t>& w,
614  bool isroot, bool partial,
615  int depth) override;
616 
617  void apply_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
618  bool isroot, int depth,
619  std::atomic<long long int>& flops) const override;
620  void apply_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
621  WorkApply<scalar_t>& w, bool isroot, int depth,
622  std::atomic<long long int>& flops) const override;
623  void applyT_fwd(const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
624  int depth, std::atomic<long long int>& flops) const override;
625  void applyT_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
626  WorkApply<scalar_t>& w, bool isroot, int depth,
627  std::atomic<long long int>& flops) const override;
628 
629  void solve_fwd(const DenseM_t& b, WorkSolve<scalar_t>& w,
630  bool partial, bool isroot, int depth) const override;
631  void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
632  bool isroot, int depth) const override;
633 
634  void extract_fwd(WorkExtract<scalar_t>& w,
635  bool odiag, int depth) const override;
636  void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
637  int depth) const override;
638  void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
639  WorkExtract<scalar_t>& w, int depth) const override;
640  void extract_bwd_internal(WorkExtract<scalar_t>& w, int depth) const;
641 
642  void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
643  DenseM_t& Vop,
644  const std::pair<std::size_t, std::size_t>& offset,
645  int depth, std::atomic<long long int>& flops)
646  const override;
647  void apply_UtVt_big(const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
648  const std::pair<std::size_t, std::size_t>& offset,
649  int depth, std::atomic<long long int>& flops)
650  const override;
651 
652  void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
653  bool isroot, int depth) const override;
654 
658  template<typename T> friend
659  void apply_HSS(Trans ta, const HSSMatrix<T>& a, const DenseMatrix<T>& b,
660  T beta, DenseMatrix<T>& c);
661 
665  template<typename T> friend
666  void draw(const HSSMatrix<T>& H, const std::string& name);
667 
668  void read(std::ifstream& is) override;
669  void write(std::ofstream& os) const override;
670 
671  friend class HSSMatrixMPI<scalar_t>;
672 
674  };
675 
683  template<typename scalar_t>
684  void draw(const HSSMatrix<scalar_t>& H, const std::string& name);
685 
697  template<typename scalar_t> void
699  const DenseMatrix<scalar_t>& B,
700  scalar_t beta, DenseMatrix<scalar_t>& C);
701 
702  } // end namespace HSS
703 } // end namespace strumpack
704 
705 #endif // HSS_MATRIX_HPP
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representation.
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:79
void factor() override
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:193
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
scalar_t get(std::size_t i, std::size_t j) const
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
void write(const std::string &fname) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:203
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
DenseM_t dense() const
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:118
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree, that represents the partitioning of the rows or columns of a hierarchical matrix.
Definition: ClusterTree.hpp:62
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:42
Trans
Definition: DenseMatrix.hpp:51