HSSMatrix.hpp
Go to the documentation of this file.
1/*
2 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3 * Regents of the University of California, through Lawrence Berkeley
4 * National Laboratory (subject to receipt of any required approvals
5 * from the U.S. Dept. of Energy). All rights reserved.
6 *
7 * If you have questions about your rights to use or distribute this
8 * software, please contact Berkeley Lab's Technology Transfer
9 * Department at TTD@lbl.gov.
10 *
11 * NOTICE. This software is owned by the U.S. Department of Energy. As
12 * such, the U.S. Government has been granted for itself and others
13 * acting on its behalf a paid-up, nonexclusive, irrevocable,
14 * worldwide license in the Software to reproduce, prepare derivative
15 * works, and perform publicly and display publicly. Beginning five
16 * (5) years after the date permission to assert copyright is obtained
17 * from the U.S. Department of Energy, and subject to any subsequent
18 * five (5) year renewals, the U.S. Government is granted for itself
19 * and others acting on its behalf a paid-up, nonexclusive,
20 * irrevocable, worldwide license in the Software to reproduce,
21 * prepare derivative works, distribute copies to the public, perform
22 * publicly and display publicly, and to permit others to do so.
23 *
24 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25 * (Lawrence Berkeley National Lab, Computational Research
26 * Division).
27 *
28 */
37#ifndef HSS_MATRIX_HPP
38#define HSS_MATRIX_HPP
39
40#include <cassert>
41#include <functional>
42#include <string>
43
44#include "HSSBasisID.hpp"
45#include "HSSOptions.hpp"
46#include "HSSExtra.hpp"
47#include "HSSMatrixBase.hpp"
48#include "kernel/Kernel.hpp"
49#include "HSSMatrix.sketch.hpp"
50
51namespace strumpack {
52 namespace HSS {
53
// Hidden from Doxygen: this only introduces the name HSSMatrixMPI so
// that HSSMatrix can declare it as a friend further down.
54#ifndef DOXYGEN_SHOULD_SKIP_THIS
55 // forward declaration (the class itself is defined in another header)
56 template<typename scalar_t> class HSSMatrixMPI;
57#endif /* DOXYGEN_SHOULD_SKIP_THIS */
58
59
/**
 * Class to represent a sequential/threaded Hierarchically
 * Semi-Separable (HSS) matrix. It is the concrete counterpart of the
 * abstract base class HSSMatrixBase<scalar_t>, most of whose virtual
 * interface is overridden below.
 *
 * NOTE(review): this listing was recovered from generated
 * documentation. Declaration lines that the extraction dropped have
 * been put back at their original line numbers and are marked
 * "restored"; please diff them against the upstream header.
 *
 * \tparam scalar_t scalar type of the matrix entries; real_t is the
 * matching real type obtained through RealType<scalar_t>.
 */
79 template<typename scalar_t> class HSSMatrix
80 : public HSSMatrixBase<scalar_t> {
81 using real_t = typename RealType<scalar_t>::value_type;
   // Restored (82-83): DenseM_t is used unqualified throughout this
   // class template, and names from the dependent base class are not
   // found by unqualified lookup, so it has to be declared here.
   // DenseMW_t is not used in the visible declarations.
   // NOTE(review): confirm both aliases against upstream.
82 using DenseM_t = DenseMatrix<scalar_t>;
83 using DenseMW_t = DenseMatrixWrapper<scalar_t>;
   // Element callback: receives row indices I, column indices J and a
   // writable matrix B (presumably B is filled with the sub-block
   // A(I,J) -- confirm with the implementation).
84 using elem_t = typename std::function
85 <void(const std::vector<std::size_t>& I,
86 const std::vector<std::size_t>& J, DenseM_t& B)>;
   // Product callback taking four writable matrices Rr, Rc, Sr, Sc
   // (presumably random matrices and the resulting row/column samples
   // -- confirm with the implementation).
87 using mult_t = typename std::function
88 <void(DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc)>;
   // Restored (89): opts_t is used by nearly every signature below.
   // NOTE(review): confirm against upstream.
89 using opts_t = HSSOptions<scalar_t>;
90
91 public:
96
   // Construct from a dense matrix A with the given options.
   // Presumably builds a compressed HSS representation of A -- confirm.
114 HSSMatrix(const DenseM_t& A, const opts_t& opts);
115
   // Construct an m x n HSS matrix with the given options, without
   // any input matrix data.
131 HSSMatrix(std::size_t m, std::size_t n, const opts_t& opts);
132
   // Restored (143): construct from a cluster (partition) tree.
143 HSSMatrix(const structured::ClusterTree& t, const opts_t& opts);
144
   // Restored (154): construct from a kernel matrix K (taken by
   // non-const reference).
154 HSSMatrix(kernel::Kernel<real_t>& K, const opts_t& opts);
155
   // Restored (161): user-declared copy constructor.
161 HSSMatrix(const HSSMatrix<scalar_t>& other);
162
   // Restored (168): user-declared copy assignment.
168 HSSMatrix<scalar_t>& operator=(const HSSMatrix<scalar_t>& other);
169
   // Defaulted move constructor.
174 HSSMatrix(HSSMatrix<scalar_t>&& other) = default;
175
   // Restored (180): defaulted move assignment.
180 HSSMatrix<scalar_t>& operator=(HSSMatrix<scalar_t>&& other) = default;
181
   // Polymorphic copy, returned as an owning pointer to the base type.
186 std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
187
   // Read-only access to child c, downcast from the base-class
   // pointer stored in this->ch_. The dynamic_cast yields nullptr when
   // that child is not an HSSMatrix. c indexes this->ch_ directly; no
   // range check is performed here.
194 const HSSMatrix<scalar_t>* child(int c) const {
195 return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
196 }
197
   // Restored (204): the signature of the mutable overload was
   // missing, leaving its body orphaned. Same semantics as the const
   // overload above, but returns a pointer to non-const.
204 HSSMatrix<scalar_t>* child(int c) {
205 return dynamic_cast<HSSMatrix<scalar_t>*>(this->ch_[c].get());
206 }
207
   // Compress using an explicitly stored dense matrix A.
224 void compress(const DenseM_t& A, const opts_t& opts);
225
   // Compress using two callbacks instead of a dense matrix: Amult
   // has the shape of mult_t and Aelem the shape of elem_t (see the
   // aliases at the top of the class).
270 void compress(const std::function<void(DenseM_t& Rr,
271 DenseM_t& Rc,
272 DenseM_t& Sr,
273 DenseM_t& Sc)>& Amult,
274 const std::function<void(const std::vector<std::size_t>& I,
275 const std::vector<std::size_t>& J,
276 DenseM_t& B)>& Aelem,
277 const opts_t& opts);
278
279
   // Restored (306): the declarator line was missing, leaving only
   // the tail of the parameter list. Compression driven by a matrix
   // of coordinates plus an element callback.
   // NOTE(review): the layout of coords (points as rows or columns)
   // is not visible here -- confirm.
306 void compress_with_coordinates(const DenseMatrix<real_t>& coords,
307 const std::function
308 <void(const std::vector<std::size_t>& I,
309 const std::vector<std::size_t>& J,
310 DenseM_t& B)>& Aelem,
311 const opts_t& opts);
312
317 void reset() override;
318
   // Factor this matrix; the result is presumably what ULV() exposes
   // (HSSFactors holds "data related to ULV factorization").
322 void factor() override;
323
   // NOTE(review): original line 334 also appears to have been
   // dropped by the extraction (presumably another factorization
   // entry point). It cannot be recovered from this listing; check
   // the upstream header.
335
   // Solve in place: b is taken by non-const reference (presumably
   // overwritten with the solution; requires a prior factor()).
347 void solve(DenseM_t& b) const override;
348
   // Two-stage solve using an explicit work object w; solve() above
   // is presumably a forward_solve followed by a backward_solve.
364 void forward_solve(WorkSolve<scalar_t>& w, const DenseM_t& b,
365 bool partial) const override;
366
380 void backward_solve(WorkSolve<scalar_t>& w, DenseM_t& x) const override;
381
   // Multiply with b and return the result as a new matrix.
390 DenseM_t apply(const DenseM_t& b) const;
391
   // Product with the operation selected by op, written into y.
404 void mult(Trans op, const DenseM_t& x, DenseM_t& y) const override;
405
   // Like apply(); the 'C' presumably stands for the conjugate
   // transpose of this matrix -- confirm.
415 DenseM_t applyC(const DenseM_t& b) const;
416
   // Single-entry access by row i and column j.
427 scalar_t get(std::size_t i, std::size_t j) const;
428
   // Return the sub-block selected by row indices I and column
   // indices J as a new dense matrix.
439 DenseM_t extract(const std::vector<std::size_t>& I,
440 const std::vector<std::size_t>& J) const;
441
   // As extract(), but accumulates into the caller-provided B
   // instead of returning a new matrix (per the "_add" suffix --
   // confirm).
455 void extract_add(const std::vector<std::size_t>& I,
456 const std::vector<std::size_t>& J,
457 DenseM_t& B) const;
458
   // Internal helpers, public but hidden from the generated docs.
459#ifndef DOXYGEN_SHOULD_SKIP_THIS
460 void Schur_update(DenseM_t& Theta,
461 DenseM_t& DUB01,
462 DenseM_t& Phi) const;
   // The fourth parameter was spelled "_ThetaVhatC_or_VhatCPhiC";
   // identifiers starting with an underscore followed by an uppercase
   // letter are reserved for the implementation, so the leading
   // underscore is dropped. Parameter names in a declaration are not
   // part of the interface, so callers are unaffected.
463 void Schur_product_direct(const DenseM_t& Theta,
464 const DenseM_t& DUB01,
465 const DenseM_t& Phi,
466 const DenseM_t& ThetaVhatC_or_VhatCPhiC,
467 const DenseM_t& R,
468 DenseM_t& Sr, DenseM_t& Sc) const;
469 void Schur_product_indirect(const DenseM_t& DUB01,
470 const DenseM_t& R1,
471 const DenseM_t& R2, const DenseM_t& Sr2,
472 const DenseM_t& Sc2,
473 DenseM_t& Sr, DenseM_t& Sc) const;
474 void delete_trailing_block() override;
475#endif // DOXYGEN_SHOULD_SKIP_THIS
476
   // Size/structure queries overriding the base-class interface.
477 std::size_t rank() const override;
478 std::size_t memory() const override;
479 std::size_t nonzeros() const override;
480 std::size_t levels() const override;
481
   // Print information to out (std::cout by default); roff/coff are
   // row/column offsets defaulting to 0.
482 void print_info(std::ostream& out=std::cout,
483 std::size_t roff=0,
484 std::size_t coff=0) const override;
485
   // Restored (492): return this matrix as a dense matrix.
492 DenseM_t dense() const;
493
   // Shift by sigma (presumably adds sigma to the diagonal -- confirm).
494 void shift(scalar_t sigma) override;
495
   // Stream-based drawing; rlo/clo are row/column offsets, default 0.
496 void draw(std::ostream& of,
497 std::size_t rlo=0, std::size_t clo=0) const override;
498
   // Write this matrix to the file named fname.
505 void write(const std::string& fname) const;
506
   // Counterpart of write(): build an HSSMatrix from the file fname.
513 static HSSMatrix<scalar_t> read(const std::string& fname);
514
   // Read-only view of the base-class member ULV_.
   // NOTE(review): this accessor is not const-qualified, so it cannot
   // be called on a const HSSMatrix; consider adding const.
515 const HSSFactors<scalar_t>& ULV() { return this->ULV_; }
516
517 protected:
   // Internal constructors carrying an extra 'active' flag.
518 HSSMatrix(std::size_t m, std::size_t n,
519 const opts_t& opts, bool active);
   // Restored (520): the first line of this declaration was missing.
   // NOTE(review): reconstructed by analogy with the public
   // ClusterTree constructor and the overload just above -- confirm.
520 HSSMatrix(const structured::ClusterTree& t,
521 const opts_t& opts, bool active);
   // Construct by reading from an already opened input file stream.
522 HSSMatrix(std::ifstream& is);
523
   // Per-node data: bases U_, V_ and dense blocks D_, B01_, B10_
   // (presumably the diagonal block and the two off-diagonal coupling
   // blocks -- confirm).
524 HSSBasisID<scalar_t> U_, V_;
525 DenseM_t D_, B01_, B10_;
526
   // Compression drivers. Each variant (original / stable /
   // hard_restart) has a dense-matrix overload and a callback
   // overload; which one runs is presumably selected through opts.
527 void compress_original(const DenseM_t& A,
528 const opts_t& opts);
529 void compress_original(const mult_t& Amult,
530 const elem_t& Aelem,
531 const opts_t& opts);
532 void compress_stable(const DenseM_t& A,
533 const opts_t& opts);
534 void compress_stable(const mult_t& Amult,
535 const elem_t& Aelem,
536 const opts_t& opts);
537 void compress_hard_restart(const DenseM_t& A,
538 const opts_t& opts);
539 void compress_hard_restart(const mult_t& Amult,
540 const elem_t& Aelem,
541 const opts_t& opts);
542
   // Recursive compression kernels and their helpers. Rr/Rc and Sr/Sc
   // follow the naming of mult_t; w is per-node workspace; depth is
   // the recursion depth. The exact meaning of d/dd/d0 is not visible
   // in this header.
543 void compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
544 DenseM_t& Sr, DenseM_t& Sc,
545 const elem_t& Aelem,
546 const opts_t& opts,
547 WorkCompress<scalar_t>& w,
548 int dd, int depth) override;
549 void compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
550 DenseM_t& Sr, DenseM_t& Sc,
551 const elem_t& Aelem,
552 const opts_t& opts,
553 WorkCompress<scalar_t>& w,
554 int d, int dd, int depth) override;
555 // SJLT_Matrix<scalar_t,int>* S=nullptr
   // S is an optional SJLT matrix; it defaults to nullptr.
556 void compute_local_samples(DenseM_t& Rr, DenseM_t& Rc,
557 DenseM_t& Sr, DenseM_t& Sc,
558 WorkCompress<scalar_t>& w,
559 int d0, int d, int depth,
560 SJLTMatrix<scalar_t, int>* S=nullptr);
561 bool compute_U_V_bases(DenseM_t& Sr, DenseM_t& Sc, const opts_t& opts,
562 WorkCompress<scalar_t>& w, int d, int depth);
563 void compute_U_basis_stable(DenseM_t& Sr, const opts_t& opts,
564 WorkCompress<scalar_t>& w,
565 int d, int dd, int depth);
566 void compute_V_basis_stable(DenseM_t& Sc, const opts_t& opts,
567 WorkCompress<scalar_t>& w,
568 int d, int dd, int depth);
569 void reduce_local_samples(DenseM_t& Rr, DenseM_t& Rc,
570 WorkCompress<scalar_t>& w,
571 int d0, int d, int depth);
572 bool update_orthogonal_basis(const opts_t& opts, scalar_t& r_max_0,
573 const DenseM_t& S, DenseM_t& Q,
574 int d, int dd, bool untouched,
575 int L, int depth);
576 void set_U_full_rank(WorkCompress<scalar_t>& w);
577 void set_V_full_rank(WorkCompress<scalar_t>& w);
578
   // Level-by-level counterparts of the recursive kernels above; lvl
   // selects the tree level being processed.
579 void compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
580 DenseM_t& Sr, DenseM_t& Sc,
581 const opts_t& opts,
582 WorkCompress<scalar_t>& w,
583 int dd, int lvl, int depth) override;
584 void compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
585 DenseM_t& Sr, DenseM_t& Sc,
586 const opts_t& opts,
587 WorkCompress<scalar_t>& w,
588 int d, int dd, int lvl, int depth) override;
   // Collect index sets I/J (and, in the second overload, the target
   // blocks B) for a batched element extraction at level lvl.
589 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
590 std::vector<std::vector<std::size_t>>& J,
591 const std::pair<std::size_t,std::size_t>& off,
592 WorkCompress<scalar_t>& w,
593 int& self, int lvl) override;
594 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
595 std::vector<std::vector<std::size_t>>& J,
596 std::vector<DenseM_t*>& B,
597 const std::pair<std::size_t,std::size_t>& off,
598 WorkCompress<scalar_t>& w,
599 int& self, int lvl) override;
600 void extract_D_B(const elem_t& Aelem, const opts_t& opts,
601 WorkCompress<scalar_t>& w, int lvl) override;
602
   // Kernel-matrix compression and its "ann" helpers, which take an
   // index matrix ann and a matching matrix of real-valued scores
   // (presumably approximate nearest neighbours and their distances
   // -- confirm).
603 void compress(const kernel::Kernel<real_t>& K, const opts_t& opts);
604 void compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
605 DenseMatrix<real_t>& scores,
606 const elem_t& Aelem, const opts_t& opts,
607 WorkCompressANN<scalar_t>& w,
608 int depth) override;
609 void compute_local_samples_ann(DenseMatrix<std::uint32_t>& ann,
610 DenseMatrix<real_t>& scores,
611 WorkCompressANN<scalar_t>& w,
612 const elem_t& Aelem, const opts_t& opts);
613 bool compute_U_V_bases_ann(DenseM_t& S, const opts_t& opts,
614 WorkCompressANN<scalar_t>& w, int depth);
615
   // Recursive worker behind factor().
616 void factor_recursive(WorkFactor<scalar_t>& w,
617 bool isroot, bool partial,
618 int depth) override;
619
   // Forward/backward sweeps for the (transposed) matrix product;
   // flops is an atomic counter shared between the sweeps.
620 void apply_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
621 bool isroot, int depth,
622 std::atomic<long long int>& flops) const override;
623 void apply_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
624 WorkApply<scalar_t>& w, bool isroot, int depth,
625 std::atomic<long long int>& flops) const override;
626 void applyT_fwd(const DenseM_t& b, WorkApply<scalar_t>& w, bool isroot,
627 int depth, std::atomic<long long int>& flops) const override;
628 void applyT_bwd(const DenseM_t& b, scalar_t beta, DenseM_t& c,
629 WorkApply<scalar_t>& w, bool isroot, int depth,
630 std::atomic<long long int>& flops) const override;
631
   // Forward/backward sweeps of the solve.
632 void solve_fwd(const DenseM_t& b, WorkSolve<scalar_t>& w,
633 bool partial, bool isroot, int depth) const override;
634 void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
635 bool isroot, int depth) const override;
636
   // Forward/backward sweeps of element extraction; the backward
   // sweep can emit either a dense block or a list of triplets.
637 void extract_fwd(WorkExtract<scalar_t>& w,
638 bool odiag, int depth) const override;
639 void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
640 int depth) const override;
641 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
642 WorkExtract<scalar_t>& w, int depth) const override;
643 void extract_bwd_internal(WorkExtract<scalar_t>& w, int depth) const;
644
645 void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi,
646 DenseM_t& Vop,
647 const std::pair<std::size_t, std::size_t>& offset,
648 int depth, std::atomic<long long int>& flops)
649 const override;
650 void apply_UtVt_big(const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
651 const std::pair<std::size_t, std::size_t>& offset,
652 int depth, std::atomic<long long int>& flops)
653 const override;
654
   // Recursive worker that writes into the dense matrix A.
655 void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
656 bool isroot, int depth) const override;
657
   // The free functions apply_HSS and draw get access to the
   // protected members.
661 template<typename T> friend
662 void apply_HSS(Trans ta, const HSSMatrix<T>& a, const DenseMatrix<T>& b,
663 T beta, DenseMatrix<T>& c);
664
668 template<typename T> friend
669 void draw(const HSSMatrix<T>& H, const std::string& name);
670
   // Stream-based (de)serialization overriding the base interface;
   // the public file-name based write()/static read() above
   // presumably delegate to these -- confirm.
671 void read(std::ifstream& is) override;
672 void write(std::ofstream& os) const override;
673
   // The distributed-memory class is granted full access.
674 friend class HSSMatrixMPI<scalar_t>;
675
   // Make the base-class child() overloads visible next to the two
   // child(int) overloads declared in this class.
676 using HSSMatrixBase<scalar_t>::child;
677 };
678
// Free-function overload of draw for an HSS matrix H; it is declared
// a friend inside HSSMatrix. name is a string identifier for the
// output (presumably a file name or prefix -- confirm with the
// implementation).
686 template<typename scalar_t>
687 void draw(const HSSMatrix<scalar_t>& H, const std::string& name);
688
// Free-function matrix product with an HSS matrix; declared a friend
// inside HSSMatrix. The declarator line (701) was missing from this
// listing, leaving "void" followed directly by the tail of the
// parameter list; it is restored here.
//   op    operation applied to A (see Trans)
//   A     the HSS matrix
//   B     dense right-hand operand
//   beta  scalar applied to C (presumably C = op(A)*B + beta*C --
//         confirm with the implementation)
//   C     dense result, updated in place
700 template<typename scalar_t> void
701 apply_HSS(Trans op, const HSSMatrix<scalar_t>& A,
702 const DenseMatrix<scalar_t>& B,
703 scalar_t beta, DenseMatrix<scalar_t>& C);
704
705 } // end namespace HSS
706} // end namespace strumpack
707
708#endif // HSS_MATRIX_HPP
This file contains the HSSMatrixBase class definition, an abstract class for HSS matrix representatio...
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Contains data related to ULV factorization of an HSS matrix.
Definition: HSSExtra.hpp:161
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
bool active() const
Definition: HSSMatrixBase.hpp:239
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
Class to represent a sequential/threaded Hierarchically Semi-Separable matrix.
Definition: HSSMatrix.hpp:80
void factor() override
void forward_solve(WorkSolve< scalar_t > &w, const DenseM_t &b, bool partial) const override
std::size_t levels() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
const HSSMatrix< scalar_t > * child(int c) const
Definition: HSSMatrix.hpp:194
HSSMatrix(const DenseM_t &A, const opts_t &opts)
friend void draw(const HSSMatrix< T > &H, const std::string &name)
void draw(std::ostream &of, std::size_t rlo=0, std::size_t clo=0) const override
void extract_add(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B) const
HSSMatrix< scalar_t > & operator=(HSSMatrix< scalar_t > &&other)=default
scalar_t get(std::size_t i, std::size_t j) const
HSSMatrix< scalar_t > & operator=(const HSSMatrix< scalar_t > &other)
void shift(scalar_t sigma) override
void compress_with_coordinates(const DenseMatrix< real_t > &coords, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
void compress(const std::function< void(DenseM_t &Rr, DenseM_t &Rc, DenseM_t &Sr, DenseM_t &Sc)> &Amult, const std::function< void(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseM_t &B)> &Aelem, const opts_t &opts)
std::size_t memory() const override
DenseM_t applyC(const DenseM_t &b) const
static HSSMatrix< scalar_t > read(const std::string &fname)
HSSMatrix(HSSMatrix< scalar_t > &&other)=default
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void write(const std::string &fname) const
HSSMatrix(const HSSMatrix< scalar_t > &other)
HSSMatrix(kernel::Kernel< real_t > &K, const opts_t &opts)
void backward_solve(WorkSolve< scalar_t > &w, DenseM_t &x) const override
friend void apply_HSS(Trans ta, const HSSMatrix< T > &a, const DenseMatrix< T > &b, T beta, DenseMatrix< T > &c)
DenseM_t extract(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J) const
void solve(DenseM_t &b) const override
void compress(const DenseM_t &A, const opts_t &opts)
HSSMatrix(std::size_t m, std::size_t n, const opts_t &opts)
std::size_t rank() const override
HSSMatrix(const structured::ClusterTree &t, const opts_t &opts)
DenseM_t dense() const
void mult(Trans op, const DenseM_t &x, DenseM_t &y) const override
HSSMatrix< scalar_t > * child(int c)
Definition: HSSMatrix.hpp:204
std::size_t nonzeros() const override
DenseM_t apply(const DenseM_t &b) const
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition: ClusterTree.hpp:67
void apply_HSS(Trans op, const HSSMatrix< scalar_t > &A, const DenseMatrix< scalar_t > &B, scalar_t beta, DenseMatrix< scalar_t > &C)
void draw(const HSSMatrix< scalar_t > &H, const std::string &name)
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51