// HSSMatrixBase.hpp
// (Doxygen source listing; see the generated documentation for this file.)
/*
 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
 * Regents of the University of California, through Lawrence Berkeley
 * National Laboratory (subject to receipt of any required approvals
 * from the U.S. Dept. of Energy). All rights reserved.
 *
 * If you have questions about your rights to use or distribute this
 * software, please contact Berkeley Lab's Technology Transfer
 * Department at TTD@lbl.gov.
 *
 * NOTICE. This software is owned by the U.S. Department of Energy. As
 * such, the U.S. Government has been granted for itself and others
 * acting on its behalf a paid-up, nonexclusive, irrevocable,
 * worldwide license in the Software to reproduce, prepare derivative
 * works, and perform publicly and display publicly. Beginning five
 * (5) years after the date permission to assert copyright is obtained
 * from the U.S. Department of Energy, and subject to any subsequent
 * five (5) year renewals, the U.S. Government is granted for itself
 * and others acting on its behalf a paid-up, nonexclusive,
 * irrevocable, worldwide license in the Software to reproduce,
 * prepare derivative works, distribute copies to the public, perform
 * publicly and display publicly, and to permit others to do so.
 *
 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
 *             (Lawrence Berkeley National Lab, Computational Research
 *             Division).
 *
 */
#ifndef HSS_MATRIX_BASE_HPP
#define HSS_MATRIX_BASE_HPP

#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "dense/DenseMatrix.hpp"
#include "misc/Triplet.hpp"
#include "HSSOptions.hpp"
#include "HSSExtra.hpp"
#include "structured/StructuredMatrix.hpp"
#if defined(STRUMPACK_USE_MPI)
#include "dense/DistributedMatrix.hpp"
#include "HSSExtraMPI.hpp"
#include "HSSMatrixMPI.hpp"
#endif

55namespace strumpack {
56 namespace HSS {
57
58#ifndef DOXYGEN_SHOULD_SKIP_THIS
59 template<typename scalar_t> class HSSMatrix;
60#if defined(STRUMPACK_USE_MPI)
61 template<typename scalar_t> class HSSMatrixMPI;
62 template<typename scalar_t> class DistSubLeaf;
63 template<typename scalar_t> class DistSamples;
64#endif //defined(STRUMPACK_USE_MPI)
65#endif //DOXYGEN_SHOULD_SKIP_THIS
66
67
82 template<typename scalar_t> class HSSMatrixBase
83 : public structured::StructuredMatrix<scalar_t> {
84 using real_t = typename RealType<scalar_t>::value_type;
87 using elem_t = typename std::function
88 <void(const std::vector<std::size_t>& I,
89 const std::vector<std::size_t>& J, DenseM_t& B)>;
91#if defined(STRUMPACK_USE_MPI)
94 using delem_t = typename std::function
95 <void(const std::vector<std::size_t>& I,
96 const std::vector<std::size_t>& J, DistM_t& B)>;
97#endif //defined(STRUMPACK_USE_MPI)
98
99 public:
108 HSSMatrixBase(std::size_t m, std::size_t n, bool active);
109
113 virtual ~HSSMatrixBase() = default;
114
120
127
133
140
147 virtual std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const = 0;
148
155 std::pair<std::size_t,std::size_t> dims() const {
156 return std::make_pair(rows_, cols_);
157 }
158
163 std::size_t rows() const override { return rows_; }
164
169 std::size_t cols() const override { return cols_; }
170
175 bool leaf() const { return ch_.empty(); }
176
177 virtual std::size_t factor_nonzeros() const;
178
188 const HSSMatrixBase<scalar_t>& child(int c) const {
189 assert(c>=0 && c<int(ch_.size())); return *(ch_[c]);
190 }
191
202 assert(c>=0 && c<int(ch_.size())); return *(ch_[c]);
203 }
204
213 bool is_compressed() const {
214 return U_state_ == State::COMPRESSED &&
215 V_state_ == State::COMPRESSED;
216 }
217
228 bool is_untouched() const {
229 return U_state_ == State::UNTOUCHED &&
230 V_state_ == State::UNTOUCHED;
231 }
232
239 bool active() const { return active_; }
240
246 virtual std::size_t levels() const = 0;
247
260 virtual void print_info(std::ostream &out=std::cout,
261 std::size_t roff=0,
262 std::size_t coff=0) const = 0;
263
273 void set_openmp_task_depth(int depth) { openmp_task_depth_ = depth; }
274
275#ifndef DOXYGEN_SHOULD_SKIP_THIS
276 virtual void delete_trailing_block() { if (ch_.size()==2) ch_.resize(1); }
277 virtual void reset() {
278 U_state_ = V_state_ = State::UNTOUCHED;
279 U_rank_ = U_rows_ = V_rank_ = V_rows_ = 0;
280 for (auto& c : ch_) c->reset();
281 }
282#endif
283
291 virtual void shift(scalar_t sigma) override = 0;
292
299 virtual void draw(std::ostream& of,
300 std::size_t rlo,
301 std::size_t clo) const {}
302
303#if defined(STRUMPACK_USE_MPI)
304 virtual void forward_solve(WorkSolveMPI<scalar_t>& w,
305 const DistM_t& b, bool partial) const;
306 virtual void backward_solve(WorkSolveMPI<scalar_t>& w,
307 DistM_t& x) const;
308
309 virtual const BLACSGrid* grid() const { return nullptr; }
310 virtual const BLACSGrid* grid(const BLACSGrid* local_grid) const {
311 return active() ? local_grid : nullptr;
312 }
313 virtual const BLACSGrid* grid_local() const { return nullptr; }
314 virtual int Ptotal() const { return 1; }
315 virtual int Pactive() const { return 1; }
316
317 virtual void to_block_row(const DistM_t& A, DenseM_t& sub_A,
318 DistM_t& leaf_A) const;
319 virtual void allocate_block_row(int d, DenseM_t& sub_A,
320 DistM_t& leaf_A) const;
321 virtual void from_block_row(DistM_t& A, const DenseM_t& sub_A,
322 const DistM_t& leaf_A,
323 const BLACSGrid* lg) const;
324#endif //defined(STRUMPACK_USE_MPI)
325
326 protected:
327 std::size_t rows_, cols_;
328
329 // TODO store children array in the sub-class???
330 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> ch_;
331 State U_state_, V_state_;
332 int openmp_task_depth_;
333 bool active_;
334
335 int U_rank_ = 0, U_rows_ = 0, V_rank_ = 0, V_rows_ = 0;
336
337 // Used to redistribute the original 2D block cyclic matrix
338 // according to the HSS tree
339 DenseM_t Asub_;
340
341 HSSFactors<scalar_t> ULV_;
342#if defined(STRUMPACK_USE_MPI)
343 HSSFactorsMPI<scalar_t> ULV_mpi_;
344#endif //defined(STRUMPACK_USE_MPI)
345
346 virtual std::size_t U_rank() const { return U_rank_; }
347 virtual std::size_t V_rank() const { return V_rank_; }
348 virtual std::size_t U_rows() const { return U_rows_; }
349 virtual std::size_t V_rows() const { return V_rows_; }
350
351 virtual void
352 compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
353 DenseM_t& Sr, DenseM_t& Sc,
354 const elem_t& Aelem, const opts_t& opts,
355 WorkCompress<scalar_t>& w,
356 int dd, int depth) {}
357 virtual void
358 compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
359 DenseM_t& Sr, DenseM_t& Sc,
360 const elem_t& Aelem, const opts_t& opts,
361 WorkCompress<scalar_t>& w,
362 int d, int dd, int depth) {}
363 virtual void
364 compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
365 DenseM_t& Sr, DenseM_t& Sc,
366 const opts_t& opts, WorkCompress<scalar_t>& w,
367 int dd, int lvl, int depth) {}
368 virtual void
369 compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
370 DenseM_t& Sr, DenseM_t& Sc,
371 const opts_t& opts, WorkCompress<scalar_t>& w,
372 int d, int dd, int lvl, int depth) {}
373 virtual void
374 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
375 DenseMatrix<real_t>& scores,
376 const elem_t& Aelem, const opts_t& opts,
377 WorkCompressANN<scalar_t>& w, int depth) {}
378
379 virtual void
380 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
381 std::vector<std::vector<std::size_t>>& J,
382 const std::pair<std::size_t,std::size_t>& off,
383 WorkCompress<scalar_t>& w,
384 int& self, int lvl) {}
385
386 virtual void
387 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
388 std::vector<std::vector<std::size_t>>& J,
389 std::vector<DenseM_t*>& B,
390 const std::pair<std::size_t,std::size_t>& off,
391 WorkCompress<scalar_t>& w,
392 int& self, int lvl) {}
393 virtual void extract_D_B(const elem_t& Aelem, const opts_t& opts,
394 WorkCompress<scalar_t>& w, int lvl) {}
395
396 virtual void factor_recursive(WorkFactor<scalar_t>& w,
397 bool isroot, bool partial,
398 int depth) {}
399
400 virtual void apply_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
401 bool isroot, int depth,
402 std::atomic<long long int>& flops) const {}
403 virtual void apply_bwd(const DenseM_t& b, scalar_t beta,
404 DenseM_t& c, WorkApply<scalar_t>& w,
405 bool isroot, int depth,
406 std::atomic<long long int>& flops) const {}
407 virtual void applyT_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
408 bool isroot, int depth,
409 std::atomic<long long int>& flops) const {}
410 virtual void applyT_bwd(const DenseM_t& b, scalar_t beta,
411 DenseM_t& c, WorkApply<scalar_t>& w,
412 bool isroot, int depth,
413 std::atomic<long long int>& flops) const {}
414
415 virtual void forward_solve(WorkSolve<scalar_t>& w,
416 const DenseMatrix<scalar_t>& b,
417 bool partial) const {}
418 virtual void backward_solve(WorkSolve<scalar_t>& w,
419 DenseMatrix<scalar_t>& b) const {}
420 virtual void solve_fwd(const DenseM_t& b,
421 WorkSolve<scalar_t>& w, bool partial,
422 bool isroot, int depth) const {}
423 virtual void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
424 bool isroot, int depth) const {}
425
426 virtual void extract_fwd(WorkExtract<scalar_t>& w,
427 bool odiag, int depth) const {}
428 virtual void extract_bwd(DenseMatrix<scalar_t>& B,
429 WorkExtract<scalar_t>& w,
430 int depth) const {}
431 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
432 WorkExtract<scalar_t>& w, int depth) const {}
433
434 virtual void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop,
435 DenseM_t& Phi, DenseM_t& Vop,
436 const std::pair<std::size_t,std::size_t>& offset,
437 int depth,
438 std::atomic<long long int>& flops) const {}
439 virtual void apply_UtVt_big(const DenseM_t& A, DenseM_t& UtA,
440 DenseM_t& VtA,
441 const std::pair<std::size_t, std::size_t>& offset,
442 int depth,
443 std::atomic<long long int>& flops) const {}
444
445 virtual void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
446 bool isroot, int depth) const {}
447
448 virtual void read(std::ifstream& os) {
449 std::cerr << "ERROR read_HSS_node not implemented" << std::endl;
450 }
451 virtual void write(std::ofstream& os) const {
452 std::cerr << "ERROR write_HSS_node not implemented" << std::endl;
453 }
454
455 friend class HSSMatrix<scalar_t>;
456
457#if defined(STRUMPACK_USE_MPI)
458 using delemw_t = typename std::function
459 <void(const std::vector<std::size_t>& I,
460 const std::vector<std::size_t>& J,
461 DistM_t& B, DistM_t& A,
462 std::size_t rlo, std::size_t clo,
463 MPI_Comm comm)>;
464
465
466 virtual void
467 compress_recursive_original(DistSamples<scalar_t>& RS,
468 const delemw_t& Aelem,
469 const opts_t& opts,
470 WorkCompressMPI<scalar_t>& w, int dd);
471 virtual void
472 compress_recursive_stable(DistSamples<scalar_t>& RS,
473 const delemw_t& Aelem,
474 const opts_t& opts,
475 WorkCompressMPI<scalar_t>& w, int d, int dd);
476 virtual void
477 compress_level_original(DistSamples<scalar_t>& RS, const opts_t& opts,
478 WorkCompressMPI<scalar_t>& w, int dd, int lvl);
479 virtual void
480 compress_level_stable(DistSamples<scalar_t>& RS, const opts_t& opts,
481 WorkCompressMPI<scalar_t>& w,
482 int d, int dd, int lvl);
483 virtual void
484 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
485 DenseMatrix<real_t>& scores,
486 const delemw_t& Aelem,
487 WorkCompressMPIANN<scalar_t>& w,
488 const opts_t& opts, const BLACSGrid* lg);
489
490 virtual void
491 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
492 std::vector<std::vector<std::size_t>>& J,
493 WorkCompressMPI<scalar_t>& w,
494 int& self, int lvl);
495 virtual void
496 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
497 std::vector<std::vector<std::size_t>>& J,
498 std::vector<DistMW_t>& B,
499 const BLACSGrid* lg,
500 WorkCompressMPI<scalar_t>& w,
501 int& self, int lvl);
502 virtual void extract_D_B(const delemw_t& Aelem, const BLACSGrid* lg,
503 const opts_t& opts,
504 WorkCompressMPI<scalar_t>& w, int lvl);
505
506 virtual void apply_fwd(const DistSubLeaf<scalar_t>& B,
507 WorkApplyMPI<scalar_t>& w,
508 bool isroot, long long int flops) const;
509 virtual void apply_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
510 DistSubLeaf<scalar_t>& C,
511 WorkApplyMPI<scalar_t>& w,
512 bool isroot, long long int flops) const;
513 virtual void applyT_fwd(const DistSubLeaf<scalar_t>& B,
514 WorkApplyMPI<scalar_t>& w,
515 bool isroot, long long int flops) const;
516 virtual void applyT_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
517 DistSubLeaf<scalar_t>& C,
518 WorkApplyMPI<scalar_t>& w,
519 bool isroot, long long int flops) const;
520
521 virtual void factor_recursive(WorkFactorMPI<scalar_t>& w,
522 const BLACSGrid* lg, bool isroot,
523 bool partial);
524
525 virtual void solve_fwd(const DistSubLeaf<scalar_t>& b,
526 WorkSolveMPI<scalar_t>& w,
527 bool partial, bool isroot) const;
528 virtual void solve_bwd(DistSubLeaf<scalar_t>& x,
529 WorkSolveMPI<scalar_t>& w, bool isroot) const;
530
531 virtual void extract_fwd(WorkExtractMPI<scalar_t>& w,
532 const BLACSGrid* lg, bool odiag) const;
533 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
534 const BLACSGrid* lg,
535 WorkExtractMPI<scalar_t>& w) const;
536 virtual void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
537 const BLACSGrid* lg,
538 std::vector<bool>& odiag) const;
539 virtual void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
540 const BLACSGrid* lg,
541 WorkExtractBlocksMPI<scalar_t>& w) const;
542
543 virtual void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
544 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
545 long long int& flops) const;
546
547 virtual void
548 redistribute_to_tree_to_buffers(const DistM_t& A,
549 std::size_t Arlo, std::size_t Aclo,
550 std::vector<std::vector<scalar_t>>& sbuf,
551 int dest);
552 virtual void
553 redistribute_to_tree_from_buffers(const DistM_t& A,
554 std::size_t Arlo, std::size_t Aclo,
555 std::vector<scalar_t*>& pbuf);
556 virtual void delete_redistributed_input();
557
558 friend class HSSMatrixMPI<scalar_t>;
559#endif //defined(STRUMPACK_USE_MPI)
560 };
561
562 } // end namespace HSS
563} // end namespace strumpack
564
565
566#endif // HSS_MATRIX_BASE_HPP
// Doxygen cross-reference notes (tooltips extracted with the source listing):
// Contains the DenseMatrix and DenseMatrixWrapper classes, simple wrappers around BLAS/LAPACK style den...
// Contains the DistributedMatrix and DistributedMatrixWrapper classes, wrappers around ScaLAPACK/PBLAS ...
// This file contains the HSSMatrixMPI class definition as well as implementations for a number of its ...
// Contains the HSSOptions class as well as general routines for HSS options.
// Contains the structured matrix interfaces.
// This is a small wrapper class around a BLACS grid and a BLACS context.
// Definition: BLACSGrid.hpp:66
// Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
// Definition: DenseMatrix.hpp:1015
// This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
// Definition: DenseMatrix.hpp:138
// Definition: DistributedMatrix.hpp:733
// 2D block cyclically distributed matrix, as used by ScaLAPACK.
// Definition: DistributedMatrix.hpp:84
// Doxygen tooltip index for HSSMatrixBase members:
// Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
// Definition: HSSMatrixBase.hpp:83
// void set_openmp_task_depth(int depth)
// Definition: HSSMatrixBase.hpp:273
// bool active() const
// Definition: HSSMatrixBase.hpp:239
// HSSMatrixBase< scalar_t > & operator=(const HSSMatrixBase< scalar_t > &other)
// HSSMatrixBase(HSSMatrixBase &&h)=default
// HSSMatrixBase(const HSSMatrixBase< scalar_t > &other)
// virtual std::size_t levels() const =0
// virtual void draw(std::ostream &of, std::size_t rlo, std::size_t clo) const
// Definition: HSSMatrixBase.hpp:299
// std::size_t rows() const override
// Definition: HSSMatrixBase.hpp:163
// virtual void shift(scalar_t sigma) override=0
// std::pair< std::size_t, std::size_t > dims() const
// Definition: HSSMatrixBase.hpp:155
// HSSMatrixBase(std::size_t m, std::size_t n, bool active)
// virtual ~HSSMatrixBase()=default
// bool is_compressed() const
// Definition: HSSMatrixBase.hpp:213
// virtual void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const =0
// bool leaf() const
// Definition: HSSMatrixBase.hpp:175
// const HSSMatrixBase< scalar_t > & child(int c) const
// Definition: HSSMatrixBase.hpp:188
// HSSMatrixBase< scalar_t > & child(int c)
// Definition: HSSMatrixBase.hpp:201
// virtual std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const =0
// HSSMatrixBase & operator=(HSSMatrixBase &&h)=default
// std::size_t cols() const override
// Definition: HSSMatrixBase.hpp:169
// bool is_untouched() const
// Definition: HSSMatrixBase.hpp:228
// Doxygen tooltip index for other referenced entities:
// Class containing several options for the HSS code and data-structures.
// Definition: HSSOptions.hpp:152
// Class to represent a structured matrix. This is the abstract base class for several types of structur...
// Definition: StructuredMatrix.hpp:209
// State
// Definition: HSSExtra.hpp:46
// Definition: StrumpackOptions.hpp:43