HSSMatrix.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  *
28  */
37 #ifndef HSS_MATRIX_HPP
38 #define HSS_MATRIX_HPP
39 
40 #include <cassert>
41 #include <functional>
42 #include <string>
43 
44 #include "HSSPartitionTree.hpp"
45 #include "HSSBasisID.hpp"
46 #include "HSSOptions.hpp"
47 #include "HSSExtra.hpp"
48 #include "HSSMatrixBase.hpp"
49 #include "kernel/Kernel.hpp"
50 
51 namespace strumpack {
52  namespace HSS {
53 
54  // forward declaration: HSSMatrix (below) names HSSMatrixMPI as a friend class
55  template<class scalar_t> class HSSMatrixMPI;
56 
76  template<typename scalar_t> class HSSMatrix
77  : public HSSMatrixBase<scalar_t> {
78  using real_t = typename RealType<scalar_t>::value_type;
81  using elem_t = typename std::function
82  <void(const std::vector<std::size_t>& I,
83  const std::vector<std::size_t>& J, DenseM_t& B)>;
84  using mult_t = typename std::function
85  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>;
87 
88  public:
92  HSSMatrix();
93 
111  HSSMatrix(const DenseM_t& A, const opts_t& opts);
112 
128  HSSMatrix(std::size_t m, std::size_t n, const opts_t& opts);
129 
140  HSSMatrix(const HSSPartitionTree& t, const opts_t& opts);
141 
151  HSSMatrix(kernel::Kernel<real_t>& K, const opts_t& opts);
152 
158  HSSMatrix(const HSSMatrix<scalar_t>& other);
159 
166 
171  HSSMatrix(HSSMatrix<scalar_t>&& other) = default;
172 
178 
183  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
184 
191  const HSSMatrix<scalar_t>* child(int c) const {
192  return dynamic_cast<HSSMatrix<scalar_t>*>(this->_ch[c].get());
193  }
194 
202  return dynamic_cast<HSSMatrix<scalar_t>*>(this->_ch[c].get());
203  }
204 
221  void compress(const DenseM_t& A, const opts_t& opts);
222 
267  void compress
268  (const std::function
269  <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>& Amult,
270  const std::function
271  <void(const std::vector<std::size_t>& I,
272  const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
273  const opts_t& opts);
274 
275 
303  (const DenseMatrix<real_t>& coords,
304  const std::function
305  <void(const std::vector<std::size_t>& I,
306  const std::vector<std::size_t>& J, DenseM_t& B)>& Aelem,
307  const opts_t& opts);
308 
313  void reset() override;
314 
321 
334 
346  void solve(const HSSFactors<scalar_t>& ULV, DenseM_t& b) const;
347 
364  void forward_solve
365  (const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
366  const DenseM_t& b, bool partial) const override;
367 
382  void backward_solve
383  (const HSSFactors<scalar_t>& ULV,
384  WorkSolve<scalar_t>& w, DenseM_t& x) const override;
385 
394  DenseM_t apply(const DenseM_t& b) const;
395 
405  DenseM_t applyC(const DenseM_t& b) const;
406 
417  scalar_t get(std::size_t i, std::size_t j) const;
418 
429  DenseM_t extract
430  (const std::vector<std::size_t>& I,
431  const std::vector<std::size_t>& J) const;
432 
446  void extract_add
447  (const std::vector<std::size_t>& I, const std::vector<std::size_t>& J,
448  DenseM_t& B) const;
449 
450 #ifndef DOXYGEN_SHOULD_SKIP_THIS
451  void Schur_update
452  (const HSSFactors<scalar_t>& f, DenseM_t& Theta,
453  DenseM_t& DUB01, DenseM_t& Phi) const;
454  void Schur_product_direct
455  (const HSSFactors<scalar_t>& f,
456  const DenseM_t& Theta, const DenseM_t& DUB01,
457  const DenseM_t& Phi, const DenseM_t&_ThetaVhatC_or_VhatCPhiC,
458  const DenseM_t& R, DenseM_t& Sr, DenseM_t& Sc) const;
459  void Schur_product_indirect
460  (const HSSFactors<scalar_t>& f, const DenseM_t& DUB01,
461  const DenseM_t& R1, const DenseM_t& R2, const DenseM_t& Sr2,
462  const DenseM_t& Sc2, DenseM_t& Sr, DenseM_t& Sc) const;
463  void delete_trailing_block() override;
464 #endif // DOXYGEN_SHOULD_SKIP_THIS
465 
466  std::size_t rank() const override;
467  std::size_t memory() const override;
468  std::size_t nonzeros() const override;
469  std::size_t levels() const override;
470  void print_info
471  (std::ostream& out=std::cout,
472  std::size_t roff=0, std::size_t coff=0) const override;
473 
480  DenseM_t dense() const;
481 
482  void shift(scalar_t sigma) override;
483 
484  void draw(std::ostream& of,
485  std::size_t rlo=0, std::size_t clo=0) const override;
486 
493  void write(const std::string& fname) const;
494 
501  static HSSMatrix<scalar_t> read(const std::string& fname);
502 
503  protected:
504  HSSMatrix(std::size_t m, std::size_t n,
505  const opts_t& opts, bool active);
506  HSSMatrix(const HSSPartitionTree& t, const opts_t& opts, bool active);
507  HSSMatrix(std::ifstream& is);
508 
509  HSSBasisID<scalar_t> _U, _V;
510  DenseM_t _D, _B01, _B10;
511 
512  void compress_original(const DenseM_t& A, const opts_t& opts);
513  void compress_original
514  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
515  void compress_stable(const DenseM_t& A, const opts_t& opts);
516  void compress_stable
517  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
518  void compress_hard_restart(const DenseM_t& A, const opts_t& opts);
519  void compress_hard_restart
520  (const mult_t& Amult, const elem_t& Aelem, const opts_t& opts);
521 
522  void compress_recursive_original
523  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
524  const elem_t& Aelem, const opts_t& opts,
525  WorkCompress<scalar_t>& w, int dd, int depth) override;
526  void compress_recursive_stable
527  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
528  const elem_t& Aelem, const opts_t& opts,
529  WorkCompress<scalar_t>& w, int d, int dd, int depth) override;
530  void compute_local_samples
531  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
532  WorkCompress<scalar_t>& w, int d0, int d, int depth);
533  bool compute_U_V_bases
534  (DenseM_t& Sr, DenseM_t& Sc, const opts_t& opts,
535  WorkCompress<scalar_t>& w, int d, int depth);
536  void compute_U_basis_stable
537  (DenseM_t& Sr, const opts_t& opts,
538  WorkCompress<scalar_t>& w, int d, int dd, int depth);
539  void compute_V_basis_stable
540  (DenseM_t& Sc, const opts_t& opts,
541  WorkCompress<scalar_t>& w, int d, int dd, int depth);
542  void reduce_local_samples
543  (DenseM_t& Rr, DenseM_t& Rc, WorkCompress<scalar_t>& w,
544  int d0, int d, int depth);
545  bool update_orthogonal_basis
546  (const opts_t& opts, scalar_t& r_max_0,
547  const DenseM_t& S, DenseM_t& Q, int d, int dd,
548  bool untouched, int L, int depth);
549  void set_U_full_rank(WorkCompress<scalar_t>& w);
550  void set_V_full_rank(WorkCompress<scalar_t>& w);
551 
552  void compress_level_original
553  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
554  const opts_t& opts, WorkCompress<scalar_t>& w,
555  int dd, int lvl, int depth) override;
556  void compress_level_stable
557  (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
558  const opts_t& opts, WorkCompress<scalar_t>& w,
559  int d, int dd, int lvl, int depth) override;
560  void get_extraction_indices
561  (std::vector<std::vector<std::size_t>>& I,
562  std::vector<std::vector<std::size_t>>& J,
563  const std::pair<std::size_t,std::size_t>& off,
564  WorkCompress<scalar_t>& w, int& self, int lvl) override;
565  void get_extraction_indices
566  (std::vector<std::vector<std::size_t>>& I,
567  std::vector<std::vector<std::size_t>>& J, std::vector<DenseM_t*>& B,
568  const std::pair<std::size_t,std::size_t>& off,
569  WorkCompress<scalar_t>& w, int& self, int lvl) override;
570  void extract_D_B
571  (const elem_t& Aelem, const opts_t& opts,
572  WorkCompress<scalar_t>& w, int lvl) override;
573 
574  void compress
575  (const kernel::Kernel<real_t>& K, const opts_t& opts);
576  void compress_recursive_ann
578  const elem_t& Aelem, const opts_t& opts,
579  WorkCompressANN<scalar_t>& w, int depth) override;
580  void compute_local_samples_ann
582  WorkCompressANN<scalar_t>& w, const elem_t& Aelem, const opts_t& opts);
583  bool compute_U_V_bases_ann
584  (DenseM_t& S, const opts_t& opts,
585  WorkCompressANN<scalar_t>& w, int depth);
586 
587  void factor_recursive
588  (HSSFactors<scalar_t>& ULV, WorkFactor<scalar_t>& w,
589  bool isroot, bool partial, int depth) const override;
590 
591  void apply_fwd
592  (const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
593  int depth, std::atomic<long long int>& flops) const override;
594  void apply_bwd
595  (const DenseM_t& b, scalar_t beta, DenseM_t& c,
596  WorkApply<scalar_t>& w, bool isroot, int depth,
597  std::atomic<long long int>& flops) const override;
598  void applyT_fwd
599  (const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
600  int depth, std::atomic<long long int>& flops) const override;
601  void applyT_bwd
602  (const DenseM_t& b, scalar_t beta, DenseM_t& c,
603  WorkApply<scalar_t>& w, bool isroot, int depth,
604  std::atomic<long long int>& flops) const override;
605 
606  void solve_fwd
607  (const HSSFactors<scalar_t>& ULV, const DenseM_t& b,
608  WorkSolve<scalar_t>& w,
609  bool partial, bool isroot, int depth) const override;
610  void solve_bwd
611  (const HSSFactors<scalar_t>& ULV, DenseM_t& x,
612  WorkSolve<scalar_t>& w,
613  bool isroot, int depth) const override;
614 
615  void extract_fwd
616  (WorkExtract<scalar_t>& w, bool odiag, int depth) const override;
617  void extract_bwd
618  (DenseM_t& B, WorkExtract<scalar_t>& w, int depth) const override;
619  void extract_bwd
620  (std::vector<Triplet<scalar_t>>& triplets,
621  WorkExtract<scalar_t>& w, int depth) const override;
622  void extract_bwd_internal(WorkExtract<scalar_t>& w, int depth) const;
623 
624  void apply_UV_big
625  (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
626  DenseM_t& Vop, const std::pair<std::size_t, std::size_t>& offset,
627  int depth, std::atomic<long long int>& flops) const override;
628  void apply_UtVt_big
629  (const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
630  const std::pair<std::size_t, std::size_t>& offset,
631  int depth, std::atomic<long long int>& flops) const override;
632 
633  void dense_recursive
634  (DenseM_t& A, WorkDense<scalar_t>& w,
635  bool isroot, int depth) const override;
636 
640  template<typename T> friend void apply_HSS
641  (Trans ta, const HSSMatrix<T>& a, const DenseMatrix<T>& b,
642  T beta, DenseMatrix<T>& c);
643 
647  template<typename T> friend void draw
648  (const HSSMatrix<T>& H, const std::string& name);
649 
650  void read(std::ifstream& is) override;
651  void write(std::ofstream& os) const override;
652 
653  friend class HSSMatrixMPI<scalar_t>;
654  };
655 
// Namespace-scope declaration matching the `draw` friend declared inside
// HSSMatrix. NOTE(review): the name suggests it renders H to an output
// identified by `name` -- confirm against the implementation.
663  template<typename scalar_t> void draw
664  (const HSSMatrix<scalar_t>& H, const std::string& name);
665
// Namespace-scope declaration matching the `apply_HSS` friend declared inside
// HSSMatrix. NOTE(review): presumably computes C = op(A) * B + beta * C --
// confirm against the implementation.
677  template<typename scalar_t>
678  void apply_HSS(Trans op, const HSSMatrix<scalar_t>& A,
679  const DenseMatrix<scalar_t>& B, scalar_t beta, DenseMatrix<scalar_t>& C);
680 
681  } // end namespace HSS
682 } // end namespace strumpack
683 
684 #endif // HSS_MATRIX_HPP
strumpack::HSS::HSSMatrix::extract_add
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
strumpack::HSS::HSSMatrix::memory
std::size_t memory() const override
strumpack::HSS::HSSPartitionTree
The cluster tree, or partition tree, that represents the matrix partitioning of an HSS matrix.
Definition: HSSPartitionTree.hpp:65
strumpack::HSS::draw
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
strumpack::HSS::HSSMatrix::draw
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
strumpack::HSS::HSSMatrix::compress
void compress(const DenseM_t &A, const opts_t &opts)
strumpack::HSS::HSSMatrix::print_info
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
strumpack::HSS::HSSMatrix::levels
std::size_t levels() const override
strumpack::HSS::HSSMatrix::apply_HSS
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
strumpack::HSS::HSSOptions
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:117
strumpack::HSS::HSSMatrix::solve
void solve(const HSSFactors< scalar_t > &ULV, DenseM_t &b) const
strumpack::HSS::HSSMatrix::operator=
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
strumpack::HSS::HSSMatrix::backward_solve
void backward_solve(const HSSFactors< scalar_t > &ULV, WorkSolve< scalar_t > &w, DenseM_t &x) const override
strumpack
Definition: StrumpackOptions.hpp:42
strumpack::HSS::HSSMatrix::nonzeros
std::size_t nonzeros() const override
strumpack::HSS::apply_HSS
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
strumpack::DenseMatrixWrapper
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:991
Kernel.hpp
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
strumpack::HSS::HSSMatrix::partial_factor
HSSFactors< scalar_t > partial_factor() const
strumpack::HSS::HSSMatrix::compress_with_coordinates
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
HSSMatrixBase.hpp
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representation.
strumpack::DenseMatrix< scalar_t >
strumpack::HSS::HSSFactors
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
strumpack::HSS::HSSMatrix::forward_solve
void forward_solve(const HSSFactors< scalar_t > &ULV, WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
strumpack::HSS::HSSMatrix::shift
void shift(scalar_t sigma) override
strumpack::HSS::HSSMatrix::apply
DenseM_t apply(const DenseM_t &b) const
strumpack::HSS::HSSMatrix::applyC
DenseM_t applyC(const DenseM_t &b) const
strumpack::HSS::HSSMatrix::get
scalar_t get(std::size_t i, std::size_t j) const
strumpack::HSS::HSSMatrixBase::active
bool active() const
Definition: HSSMatrixBase.hpp:235
strumpack::HSS::HSSMatrix::reset
void reset() override
strumpack::HSS::HSSMatrix::child
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:191
strumpack::HSS::HSSMatrix::rank
std::size_t rank() const override
strumpack::HSS::HSSMatrix::factor
HSSFactors< scalar_t > factor() const
strumpack::HSS::HSSMatrix
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:76
strumpack::HSS::HSSMatrix::write
void write(const std::string &fname) const
strumpack::HSS::HSSMatrix::extract
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
strumpack::HSS::HSSMatrix::clone
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
HSSOptions.hpp
Contains the HSSOptions class as well as general routines for HSS options.
strumpack::CompressionType::HSS
@ HSS
strumpack::HSS::HSSMatrix::HSSMatrix
HSSMatrix()
strumpack::HSS::HSSMatrix::dense
DenseM_t dense() const
strumpack::HSS::HSSMatrixMPI
Definition: HSSMatrix.hpp:55
strumpack::HSS::HSSMatrix::child
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:201
strumpack::kernel::Kernel
Representation of a kernel matrix.
Definition: Kernel.hpp:73
strumpack::HSS::HSSMatrixBase
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:81
strumpack::Trans
Trans
Definition: DenseMatrix.hpp:50
HSSPartitionTree.hpp
This file contains the HSSPartitionTree class definition.
strumpack::HSS::HSSMatrix::read
static HSSMatrix< scalar_t > read(const std::string &fname)