HSSMatrix.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  *
28  */
37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
39 
40 #include <cassert>
41 #include <functional>
42 
43 #include "HSSPartitionTree.hpp"
44 #include "HSSBasisID.hpp"
45 #include "HSSOptions.hpp"
46 #include "HSSExtra.hpp"
47 #include "HSSMatrixBase.hpp"
48 #include "kernel/Kernel.hpp"
49 
50 namespace strumpack {
51  namespace HSS {
52 
53  // forward declaration
54  template<typename scalar_t> class HSSMatrixMPI;
55 
// HSSMatrix<scalar_t>: "Class to represent a sequential/threaded Hierarchically
// Semi-Separable matrix" (Doxygen brief). It is the concrete, non-MPI
// implementation of the abstract HSSMatrixBase<scalar_t> interface.
//
// NOTE(review): this is a scraped Doxygen source listing, not compilable code:
//  - each line keeps its original line number as a prefix;
//  - the API doc comments were stripped;
//  - several code lines are missing (see the gaps in the numbering).
// Missing per the gaps and the Doxygen index:
//  - the DenseM_t / opts_t aliases used below (gutter 78-85);
//  - the copy assignment operator=(const HSSMatrix&);
//  - the header line of the non-const child(int) (gutter 202);
//  - the name line of compress_with_coordinates (gutter 303);
//  - factor() and partial_factor();
//  - the first parameter lines of the two ANN helpers (gutter 564 and 568).
// Restore these from the original header before reuse.
//
// NOTE(review): this class uses std::ostream/std::cout, std::atomic,
// std::unique_ptr, std::string and std::vector. The header itself includes
// only <cassert> and <functional>, so the rest presumably arrive through the
// project headers. Consider including <iostream>, <atomic>, <memory>,
// <string> and <vector> directly.
75  template<typename scalar_t> class HSSMatrix
76  : public HSSMatrixBase<scalar_t> {
// Private type aliases (they precede `public:`):
//  - real_t: the real type underlying scalar_t, via RealType.
//  - elem_t: element-access callback taking row indices I and column indices J
//    and an output matrix B; presumably fills B with A(I, J). TODO confirm.
//  - mult_t: sampling callback taking random matrices Rr, Rc and output
//    samples Sr, Sc; presumably Sr = A*Rr and Sc = A^C*Rc. TODO confirm.
77  using real_t = typename RealType<scalar_t>::value_type;
80  using elem_t = typename std::function
81  <void(const std::vector<std::size_t>& I,
82  const std::vector<std::size_t>& J, DenseM_t& B)>;
83  using mult_t = typename std::function
84  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>;
86 
87  public:
// Constructors.
//  - The (m, n, opts) and partition-tree forms receive only sizes/structure
//    and no matrix data.
//  - The dense-matrix and kernel forms receive the data to represent and
//    presumably compress it right away. TODO confirm.
91  HSSMatrix();
92 
110  HSSMatrix(const DenseM_t& A, const opts_t& opts);
111 
127  HSSMatrix(std::size_t m, std::size_t n, const opts_t& opts);
128 
139  HSSMatrix(const HSSPartitionTree& t, const opts_t& opts);
140 
150  HSSMatrix
151  (kernel::Kernel<real_t>& K, const opts_t& opts);
152 
153 
// User-declared copy constructor; the move constructor below is defaulted.
159  HSSMatrix(const HSSMatrix<scalar_t>& other);
160 
167 
172  HSSMatrix(HSSMatrix<scalar_t>&& other) = default;
173 
179 
// Override of HSSMatrixBase::clone(). Returns a copy owned through a
// base-class unique_ptr (presumably built with the copy constructor).
184  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
185 
// child(c): the c-th child node, downcast to HSSMatrix via dynamic_cast.
// It therefore yields nullptr when that child is not an HSSMatrix.
// c is not range-checked here.
// NOTE(review): inside this const member the cast targets the non-const
// HSSMatrix<scalar_t>*. It compiles only because _ch[c].get() hands back a
// non-const pointer. dynamic_cast<const HSSMatrix<scalar_t>*> would state the
// intent directly.
192  const HSSMatrix<scalar_t>* child(int c) const {
193  return dynamic_cast<HSSMatrix<scalar_t>*>(this->_ch[c].get());
194  }
195 
// NOTE(review): body of the non-const overload
//   HSSMatrix<scalar_t>* child(int c)
// (Doxygen index: HSSMatrix.hpp:202). Its header line is missing from this
// listing.
203  return dynamic_cast<HSSMatrix<scalar_t>*>(this->_ch[c].get());
204  }
205 
// Compress from an explicitly stored dense matrix A, with options opts.
222  void compress(const DenseM_t& A, const opts_t& opts);
223 
// Matrix-free compression. A is supplied only through the Amult sampling
// callback and the Aelem element callback. Their signatures are identical to
// mult_t and elem_t above.
268  void compress
269  (const std::function
270  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>& Amult,
271  const std::function
272  <void(const std::vector<std::size_t>& I,
273  const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
274  const opts_t& opts);
275 
276 
// NOTE(review): the name line (gutter 303) is missing here. Per the Doxygen
// index this declaration is
//   void compress_with_coordinates(const DenseMatrix<real_t>& coords,
//                                  const std::function<...>& Aelem,
//                                  const opts_t& opts);
304  (const DenseMatrix<real_t>& coords,
305  const std::function
306  <void(const std::vector<std::size_t>& I,
307  const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
308  const opts_t& opts);
309 
// Override of HSSMatrixBase::reset().
314  void reset() override;
315 
// NOTE(review): factor() const and partial_factor() const, both returning
// HSSFactors<scalar_t> (Doxygen index), appear to have been declared in this
// gap of the listing.
322 
335 
// Solve using the factors ULV. b is taken by non-const reference and is
// presumably overwritten with the solution. TODO confirm.
347  void solve(const HSSFactors<scalar_t>& ULV, DenseM_t& b) const;
348 
// Forward and backward phases of the solve with factors ULV; both override
// HSSMatrixBase.
365  void forward_solve
366  (const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
367  const DenseM_t& b, bool partial) const override;
368 
383  void backward_solve
384  (const HSSFactors<scalar_t>& ULV,
385  WorkSolve<scalar_t>& w, DenseM_t& x) const override;
386 
// Matrix products returning a new DenseM_t.
//  - apply(b): presumably A*b.
//  - applyC(b): presumably the conjugate-transpose product A^C*b. TODO confirm.
395  DenseM_t apply(const DenseM_t& b) const;
396 
406  DenseM_t applyC(const DenseM_t& b) const;
407 
// Element access:
//  - get(i, j): a single entry.
//  - extract(I, J): the submatrix with rows I and columns J.
//  - extract_add(I, J, B): presumably adds that submatrix into B.
418  scalar_t get(std::size_t i, std::size_t j) const;
419 
430  DenseM_t extract
431  (const std::vector<std::size_t>& I,
432  const std::vector<std::size_t>& J) const;
433 
447  void extract_add
448  (const std::vector<std::size_t>& I, const std::vector<std::size_t>& J,
449  DenseM_t& B) const;
450 
// Internal Schur-complement helpers, hidden from Doxygen.
// NOTE(review): Schur_product_direct declares the parameter
// "const DenseM_t&_ThetaVhatC_or_VhatCPhiC". It lacks a space after '&', and
// the name begins with an underscore followed by an uppercase letter, which
// C++ reserves for the implementation. Renaming it in this declaration does
// not affect callers.
451 #ifndef DOXYGEN_SHOULD_SKIP_THIS
452  void Schur_update
453  (const HSSFactors<scalar_t>& f, DenseM_t& Theta,
454  DenseM_t& DUB01, DenseM_t& Phi) const;
455  void Schur_product_direct
456  (const HSSFactors<scalar_t>& f,
457  const DenseM_t& Theta, const DenseM_t& DUB01,
458  const DenseM_t& Phi, const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
459  const DenseM_t& R, DenseM_t& Sr, DenseM_t& Sc) const;
460  void Schur_product_indirect
461  (const HSSFactors<scalar_t>& f, const DenseM_t& DUB01,
462  const DenseM_t& R1, const DenseM_t& R2, const DenseM_t& Sr2,
463  const DenseM_t& Sc2, DenseM_t& Sr, DenseM_t& Sc) const;
464  void delete_trailing_block() override;
465 #endif // DOXYGEN_SHOULD_SKIP_THIS
466 
// Statistics overrides from HSSMatrixBase, plus info printing.
// print_info writes to std::cout by default, with zero row/column offsets.
467  std::size_t rank() const override;
468  std::size_t memory() const override;
469  std::size_t nonzeros() const override;
470  std::size_t levels() const override;
471  void print_info
472  (std::ostream& out=std::cout,
473  std::size_t roff=0, std::size_t coff=0) const override;
474 
// dense(): presumably expands this HSS matrix into a full DenseM_t.
481  DenseM_t dense() const;
482 
// shift(sigma): override; presumably adds sigma to the diagonal. TODO confirm.
483  void shift(scalar_t sigma) override;
484 
// draw: override taking an output stream and row/column offsets (default 0).
485  void draw
486  (std::ostream& of, std::size_t rlo=0, std::size_t clo=0) const override;
487 
488  protected:
// Protected constructors with an extra 'active' flag
// (cf. HSSMatrixBase::active()).
489  HSSMatrix
490  (std::size_t m, std::size_t n, const opts_t& opts, bool active);
491  HSSMatrix(const HSSPartitionTree& t, const opts_t& opts, bool active);
492 
// Per-node data: the bases _U and _V (HSSBasisID), and the dense blocks
// _D, _B01 and _B10.
// NOTE(review): identifiers of the form _Uppercase are reserved in C++.
// _U in particular has been defined as a macro by some C library <ctype.h>
// headers (e.g. newlib). A rename (e.g. U_, V_, D_, B01_, B10_) has to be
// applied together with every definition file and with HSSMatrixMPI, which
// is a friend.
493  HSSBasisID<scalar_t> _U;
494  HSSBasisID<scalar_t> _V;
495  DenseM_t _D;
496  DenseM_t _B01;
497  DenseM_t _B10;
498 
// Compression variants: original, stable and hard_restart. Each has a
// dense-input overload and a matrix-free (Amult, Aelem) overload.
// Presumably dispatched from compress() according to opts. TODO confirm.
499  void compress_original(const DenseM_t& A, const opts_t& opts);
500  void compress_original
501  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
502  void compress_stable(const DenseM_t& A, const opts_t& opts);
503  void compress_stable
504  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
505  void compress_hard_restart(const DenseM_t& A, const opts_t& opts);
506  void compress_hard_restart
507  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
508 
// Recursive and level-by-level compression kernels and their helpers.
// The members marked override implement HSSMatrixBase hooks.
509  void compress_recursive_original
510  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
511  const elem_t& Aelem, const opts_t& opts,
512  WorkCompress<scalar_t>& w, int dd, int depth) override;
513  void compress_recursive_stable
514  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
515  const elem_t& Aelem, const opts_t& opts,
516  WorkCompress<scalar_t>& w, int d, int dd, int depth) override;
517  void compute_local_samples
518  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
519  WorkCompress<scalar_t>& w, int d0, int d, int depth);
520  bool compute_U_V_bases
521  (DenseM_t& Sr, DenseM_t& Sc, const opts_t& opts,
522  WorkCompress<scalar_t>& w, int d, int depth);
523  void compute_U_basis_stable
524  (DenseM_t& Sr, const opts_t& opts,
525  WorkCompress<scalar_t>& w, int d, int dd, int depth);
526  void compute_V_basis_stable
527  (DenseM_t& Sc, const opts_t& opts,
528  WorkCompress<scalar_t>& w, int d, int dd, int depth);
529  void reduce_local_samples
530  (DenseM_t& Rr, DenseM_t& Rc, WorkCompress<scalar_t>& w,
531  int d0, int d, int depth);
532  bool update_orthogonal_basis
533  (const opts_t& opts, scalar_t& r_max_0,
534  const DenseM_t& S, DenseM_t& Q, int d, int dd,
535  bool untouched, int L, int depth);
536  void set_U_full_rank(WorkCompress<scalar_t>& w);
537  void set_V_full_rank(WorkCompress<scalar_t>& w);
538 
539  void compress_level_original
540  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
541  const opts_t& opts, WorkCompress<scalar_t>& w,
542  int dd, int lvl, int depth) override;
543  void compress_level_stable
544  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
545  const opts_t& opts, WorkCompress<scalar_t>& w,
546  int d, int dd, int lvl, int depth) override;
547  void get_extraction_indices
548  (std::vector<std::vector<std::size_t>>& I,
549  std::vector<std::vector<std::size_t>>& J,
550  const std::pair<std::size_t,std::size_t>& off,
551  WorkCompress<scalar_t>& w, int& self, int lvl) override;
552  void get_extraction_indices
553  (std::vector<std::vector<std::size_t>>& I,
554  std::vector<std::vector<std::size_t>>& J, std::vector<DenseM_t*>& B,
555  const std::pair<std::size_t,std::size_t>& off,
556  WorkCompress<scalar_t>& w, int& self, int lvl) override;
557  void extract_D_B
558  (const elem_t& Aelem, const opts_t& opts,
559  WorkCompress<scalar_t>& w, int lvl) override;
560 
// Kernel-based compression. This overload is protected, whereas the kernel
// constructor above is public.
561  void compress
562  (const kernel::Kernel<real_t>& K, const opts_t& opts);
// NOTE(review): the first parameter line of compress_recursive_ann (gutter 564)
// and of compute_local_samples_ann (gutter 568) is missing from this listing.
563  void compress_recursive_ann
565  const elem_t& Aelem, const opts_t& opts,
566  WorkCompressANN<scalar_t>& w, int depth) override;
567  void compute_local_samples_ann
569  WorkCompressANN<scalar_t>& w, const elem_t& Aelem, const opts_t& opts);
570  bool compute_U_V_bases_ann
571  (DenseM_t& S, const opts_t& opts,
572  WorkCompressANN<scalar_t>& w, int depth);
573 
// Recursive drivers for factorization, products (apply / applyT), solves,
// extraction and dense expansion. All are const overrides of HSSMatrixBase,
// except extract_bwd_internal, which is const but not an override.
// Flop counts are passed as std::atomic<long long int>&, presumably because
// subtrees may be processed concurrently.
574  void factor_recursive
575  (HSSFactors<scalar_t>& ULV, WorkFactor<scalar_t>& w,
576  bool isroot, bool partial, int depth) const override;
577 
578  void apply_fwd
579  (const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
580  int depth, std::atomic<long long int>& flops) const override;
581  void apply_bwd
582  (const DenseM_t& b, scalar_t beta, DenseM_t& c,
583  WorkApply<scalar_t>& w, bool isroot, int depth,
584  std::atomic<long long int>& flops) const override;
585  void applyT_fwd
586  (const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
587  int depth, std::atomic<long long int>& flops) const override;
588  void applyT_bwd
589  (const DenseM_t& b, scalar_t beta, DenseM_t& c,
590  WorkApply<scalar_t>& w, bool isroot, int depth,
591  std::atomic<long long int>& flops) const override;
592 
593  void solve_fwd
594  (const HSSFactors<scalar_t>& ULV, const DenseM_t& b,
595  WorkSolve<scalar_t>& w,
596  bool partial, bool isroot, int depth) const override;
597  void solve_bwd
598  (const HSSFactors<scalar_t>& ULV, DenseM_t& x,
599  WorkSolve<scalar_t>& w,
600  bool isroot, int depth) const override;
601 
602  void extract_fwd
603  (WorkExtract<scalar_t>& w, bool odiag, int depth) const override;
604  void extract_bwd
605  (DenseM_t& B, WorkExtract<scalar_t>& w, int depth) const override;
606  void extract_bwd
607  (std::vector<Triplet<scalar_t>>& triplets,
608  WorkExtract<scalar_t>& w, int depth) const override;
609  void extract_bwd_internal(WorkExtract<scalar_t>& w, int depth) const;
610 
611  void apply_UV_big
612  (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
613  DenseM_t& Vop, const std::pair<std::size_t, std::size_t>& offset,
614  int depth, std::atomic<long long int>& flops) const override;
615  void apply_UtVt_big
616  (const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
617  const std::pair<std::size_t, std::size_t>& offset,
618  int depth, std::atomic<long long int>& flops) const override;
619 
620  void dense_recursive
621  (DenseM_t& A, WorkDense<scalar_t>& w,
622  bool isroot, int depth) const override;
623 
// Friends: the free function templates apply_HSS and draw, and HSSMatrixMPI
// (forward-declared before this class).
627  template<typename T> friend void apply_HSS
628  (Trans ta, const HSSMatrix<T>& a, const DenseMatrix<T>& b,
629  T beta, DenseMatrix<T>& c);
630 
634  template<typename T> friend void draw
635  (const HSSMatrix<T>& H, const std::string& name);
636 
637  friend class HSSMatrixMPI<scalar_t>;
638  };
639 
// Free-function draw for an HSSMatrix, identified by a string name. It is
// declared a friend of HSSMatrix.
// What it writes, and where, is not visible in this header. It presumably
// builds on the member draw(std::ostream&, rlo, clo); check the implementation.
647  template<typename scalar_t>
648  void draw(const HSSMatrix<scalar_t>& H, const std::string& name);
649 
// Free-function product with an HSSMatrix, declared a friend of HSSMatrix.
// It presumably follows the BLAS gemm convention:
//   C = op(A) * B + beta * C, where op is selected by the Trans argument.
// TODO confirm against the implementation.
661  template<typename scalar_t> void apply_HSS
662  (Trans op, const HSSMatrix<scalar_t>& A, const DenseMatrix<scalar_t>& B,
663  scalar_t beta, DenseMatrix<scalar_t>& C);
664 
665  } // end namespace HSS
666 } // end namespace strumpack
667 
668 #endif // HSS_MATRIX_HPP
strumpack::HSS::HSSMatrix::extract_add
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
strumpack::HSS::HSSMatrix::memory
std::size_t memory() const override
strumpack::HSS::HSSPartitionTree
The cluster tree, or partition tree that represents the matrix partitioning of an HSS matrix.
Definition: HSSPartitionTree.hpp:65
strumpack::HSS::draw
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
strumpack::HSS::HSSMatrix::draw
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
strumpack::HSS::HSSMatrix::compress
void compress(const DenseM_t &A, const opts_t &opts)
strumpack::HSS::HSSMatrix::print_info
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
strumpack::HSS::HSSMatrix::levels
std::size_t levels() const override
strumpack::HSS::HSSMatrix::apply_HSS
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
strumpack::HSS::HSSOptions
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:117
strumpack::HSS::HSSMatrix::solve
void solve(const HSSFactors< scalar_t > &ULV, DenseM_t &b) const
strumpack::HSS::HSSMatrix::operator=
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
strumpack::HSS::HSSMatrix::backward_solve
void backward_solve(const HSSFactors< scalar_t > &ULV, WorkSolve< scalar_t > &w, DenseM_t &x) const override
strumpack
Definition: StrumpackOptions.hpp:42
strumpack::HSS::HSSMatrix::nonzeros
std::size_t nonzeros() const override
strumpack::HSS::apply_HSS
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
strumpack::DenseMatrixWrapper
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK routines.
Definition: DenseMatrix.hpp:946
Kernel.hpp
Definitions of several kernel functions, and helper routines. Also provides driver routines for kernel ridge regression.
strumpack::HSS::HSSMatrix::partial_factor
HSSFactors< scalar_t > partial_factor() const
strumpack::HSS::HSSMatrix::compress_with_coordinates
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
HSSMatrixBase.hpp
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representation.
strumpack::DenseMatrix< scalar_t >
strumpack::HSS::HSSFactors
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
strumpack::HSS::HSSMatrix::forward_solve
void forward_solve(const HSSFactors< scalar_t > &ULV, WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
strumpack::HSS::HSSMatrix::shift
void shift(scalar_t sigma) override
strumpack::HSS::HSSMatrix::apply
DenseM_t apply(const DenseM_t &b) const
strumpack::HSS::HSSMatrix::applyC
DenseM_t applyC(const DenseM_t &b) const
strumpack::HSS::HSSMatrix::get
scalar_t get(std::size_t i, std::size_t j) const
strumpack::HSS::HSSMatrixBase::active
bool active() const
Definition: HSSMatrixBase.hpp:233
strumpack::HSS::HSSMatrix::reset
void reset() override
strumpack::HSS::HSSMatrix::child
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:192
strumpack::HSS::HSSMatrix::rank
std::size_t rank() const override
strumpack::HSS::HSSMatrix::factor
HSSFactors< scalar_t > factor() const
strumpack::HSS::HSSMatrix
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:75
strumpack::HSS::HSSMatrix::extract
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
strumpack::HSS::HSSMatrix::clone
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
HSSOptions.hpp
Contains the HSSOptions class as well as general routines for HSS options.
strumpack::CompressionType::HSS
@ HSS
strumpack::HSS::HSSMatrix::HSSMatrix
HSSMatrix()
strumpack::HSS::HSSMatrix::dense
DenseM_t dense() const
strumpack::HSS::HSSMatrixMPI
Definition: HSSMatrix.hpp:54
strumpack::HSS::HSSMatrix::child
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:202
strumpack::kernel::Kernel
Representation of a kernel matrix.
Definition: Kernel.hpp:73
strumpack::HSS::HSSMatrixBase
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:79
strumpack::Trans
Trans
Definition: DenseMatrix.hpp:50
HSSPartitionTree.hpp
This file contains the HSSPartitionTree class definition.