Loading...
Searching...
No Matches
HSSMatrixBase.hpp
Go to the documentation of this file.
1/*
2 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3 * Regents of the University of California, through Lawrence Berkeley
4 * National Laboratory (subject to receipt of any required approvals
5 * from the U.S. Dept. of Energy). All rights reserved.
6 *
7 * If you have questions about your rights to use or distribute this
8 * software, please contact Berkeley Lab's Technology Transfer
9 * Department at TTD@lbl.gov.
10 *
11 * NOTICE. This software is owned by the U.S. Department of Energy. As
12 * such, the U.S. Government has been granted for itself and others
13 * acting on its behalf a paid-up, nonexclusive, irrevocable,
14 * worldwide license in the Software to reproduce, prepare derivative
15 * works, and perform publicly and display publicly. Beginning five
16 * (5) years after the date permission to assert copyright is obtained
17 * from the U.S. Department of Energy, and subject to any subsequent
18 * five (5) year renewals, the U.S. Government is granted for itself
19 * and others acting on its behalf a paid-up, nonexclusive,
20 * irrevocable, worldwide license in the Software to reproduce,
21 * prepare derivative works, distribute copies to the public, perform
22 * publicly and display publicly, and to permit others to do so.
23 *
24 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25 * (Lawrence Berkeley National Lab, Computational Research
26 * Division).
27 *
28 */
34#ifndef HSS_MATRIX_BASE_HPP
35#define HSS_MATRIX_BASE_HPP
36
37#include <cassert>
38#include <iostream>
39#include <fstream>
40#include <string>
41#include <vector>
42#include <functional>
43
44#include "dense/DenseMatrix.hpp"
45#include "misc/Triplet.hpp"
46#include "HSSOptions.hpp"
47#include "HSSExtra.hpp"
49#if defined(STRUMPACK_USE_MPI)
51#include "HSSExtraMPI.hpp"
52#include "HSSMatrixMPI.hpp"
53#endif
54
55namespace strumpack {
56 namespace HSS {
57
// Forward declarations, hidden from Doxygen. HSSMatrix and HSSMatrixMPI are
// named as friends of HSSMatrixBase; DistSubLeaf and DistSamples appear in the
// signatures of its MPI-only virtual methods. The last three are only declared
// when STRUMPACK_USE_MPI is defined.
58#ifndef DOXYGEN_SHOULD_SKIP_THIS
59 template<typename scalar_t> class HSSMatrix;
60#if defined(STRUMPACK_USE_MPI)
61 template<typename scalar_t> class HSSMatrixMPI;
62 template<typename scalar_t> class DistSubLeaf;
63 template<typename scalar_t> class DistSamples;
64#endif //defined(STRUMPACK_USE_MPI)
65#endif //DOXYGEN_SHOULD_SKIP_THIS
66
67
// Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
//
// A node stores its dimensions, its child nodes (ch_), the state of its U and
// V bases and its factors. Derived classes (HSSMatrix and, with MPI,
// HSSMatrixMPI are friends) override the protected virtual hooks declared
// below; most sequential hooks default to a no-op here.
//
// NOTE(review): this listing keeps the documentation generator's line number
// at the start of every line. Lines that the generator/extraction dropped
// (the four copy/move members and the signature of the mutable child()
// accessor) were restored from the member index that accompanies the listing.
82 template<typename scalar_t> class HSSMatrixBase
83 : public structured::StructuredMatrix<scalar_t> {
84 using real_t = typename RealType<scalar_t>::value_type;
    // NOTE(review): DenseM_t, opts_t, DistM_t and DistMW_t are used below but
    // are not declared in the visible listing (numbers 85-86, 90 and 92-93 are
    // absent) -- presumably the alias declarations were dropped; confirm
    // against the original header.
    //
    // Callback taking row indices I, column indices J and a dense block B.
    // Presumably it fills B with the matrix entries at (I,J) -- confirm
    // against callers.
87 using elem_t = typename std::function
88 <void(const std::vector<std::size_t>& I,
89 const std::vector<std::size_t>& J, DenseM_t& B)>;
91#if defined(STRUMPACK_USE_MPI)
    // Same callback shape as elem_t, with a distributed block B.
94 using delem_t = typename std::function
95 <void(const std::vector<std::size_t>& I,
96 const std::vector<std::size_t>& J, DistM_t& B)>;
97#endif //defined(STRUMPACK_USE_MPI)
98
99 public:
    // Construct a node from m, n and an active flag. Defined elsewhere;
    // presumably m/n are the row/column counts reported by rows()/cols() and
    // active is the value reported by active() -- confirm.
108 HSSMatrixBase(std::size_t m, std::size_t n, bool active);
109
    // Virtual destructor so derived nodes can be destroyed through the base.
113 virtual ~HSSMatrixBase() = default;
114
    // Copy/move members, restored from the member index. Because the
    // destructor is user-declared, the move operations are not generated
    // implicitly and must be declared here. The copy operations are defined
    // elsewhere; ch_ holds unique_ptrs, so they cannot be defaulted.
120 HSSMatrixBase(const HSSMatrixBase<scalar_t>& other);
127 HSSMatrixBase<scalar_t>& operator=(const HSSMatrixBase<scalar_t>& other);
133 HSSMatrixBase(HSSMatrixBase&& h) = default;
140 HSSMatrixBase& operator=(HSSMatrixBase&& h) = default;
    // Return a newly allocated copy of this node, owned by the caller.
147 virtual std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const = 0;
148
    // Return (rows_, cols_) as a pair.
155 std::pair<std::size_t,std::size_t> dims() const {
156 return std::make_pair(rows_, cols_);
157 }
158
    // Number of rows of this node.
163 std::size_t rows() const override { return rows_; }
164
    // Number of columns of this node.
169 std::size_t cols() const override { return cols_; }
170
    // True when this node has no children.
175 bool leaf() const { return ch_.empty(); }
176
    // Declared here, defined elsewhere.
177 virtual std::size_t factor_nonzeros() const;
178
    // Read-only access to child c. The index is only checked by assert, so
    // an out-of-range c is not caught when assertions are disabled.
188 const HSSMatrixBase<scalar_t>& child(int c) const {
189 assert(c>=0 && c<int(ch_.size())); return *(ch_[c]);
190 }
191
    // Mutable access to child c; same bounds check as the const overload.
201 HSSMatrixBase<scalar_t>& child(int c) {
202 assert(c>=0 && c<int(ch_.size())); return *(ch_[c]);
203 }
204
    // True only when both the U and the V state are COMPRESSED.
213 bool is_compressed() const {
214 return U_state_ == State::COMPRESSED &&
215 V_state_ == State::COMPRESSED;
216 }
217
    // True only when both the U and the V state are UNTOUCHED.
228 bool is_untouched() const {
229 return U_state_ == State::UNTOUCHED &&
230 V_state_ == State::UNTOUCHED;
231 }
232
    // Return the active flag of this node.
239 bool active() const { return active_; }
240
    // Number of levels; must be provided by the derived class.
246 virtual std::size_t levels() const = 0;
247
    // Print information about this node to out (std::cout by default).
    // roff/coff default to 0; presumably row/column offsets of this node in
    // the full matrix -- confirm.
260 virtual void print_info(std::ostream &out=std::cout,
261 std::size_t roff=0,
262 std::size_t coff=0) const = 0;
263
    // Store the OpenMP task depth for this node only; children are not
    // updated here.
273 void set_openmp_task_depth(int depth) { openmp_task_depth_ = depth; }
274
275#ifndef DOXYGEN_SHOULD_SKIP_THIS
    // If there are exactly two children, drop the second one.
276 virtual void delete_trailing_block() { if (ch_.size()==2) ch_.resize(1); }
    // Mark U and V as UNTOUCHED, zero the stored ranks/row counts and reset
    // all children recursively.
277 virtual void reset() {
278 U_state_ = V_state_ = State::UNTOUCHED;
279 U_rank_ = U_rows_ = V_rank_ = V_rows_ = 0;
280 for (auto& c : ch_) c->reset();
281 }
282#endif
283
    // Overrides the base-class shift; must be provided by the derived class.
291 virtual void shift(scalar_t sigma) override = 0;
292
    // Default implementation does nothing.
299 virtual void draw(std::ostream& of,
300 std::size_t rlo,
301 std::size_t clo) const {}
302
303#if defined(STRUMPACK_USE_MPI)
    // Distributed solve phases; declared here, defined elsewhere.
304 virtual void forward_solve(WorkSolveMPI<scalar_t>& w,
305 const DistM_t& b, bool partial) const;
306 virtual void backward_solve(WorkSolveMPI<scalar_t>& w,
307 DistM_t& x) const;
308
    // Grid accessors. The base class has no grid of its own: grid() and
    // grid_local() return nullptr, and grid(local_grid) hands back the
    // supplied grid only when this node is active.
309 virtual const BLACSGrid* grid() const { return nullptr; }
310 virtual const BLACSGrid* grid(const BLACSGrid* local_grid) const {
311 return active() ? local_grid : nullptr;
312 }
313 virtual const BLACSGrid* grid_local() const { return nullptr; }
    // Process counts; the base class reports 1 for both.
314 virtual int Ptotal() const { return 1; }
315 virtual int Pactive() const { return 1; }
316
    // Conversions between a distributed matrix and the (sub_A, leaf_A) pair;
    // declared here, defined elsewhere.
317 virtual void to_block_row(const DistM_t& A, DenseM_t& sub_A,
318 DistM_t& leaf_A) const;
319 virtual void allocate_block_row(int d, DenseM_t& sub_A,
320 DistM_t& leaf_A) const;
321 virtual void from_block_row(DistM_t& A, const DenseM_t& sub_A,
322 const DistM_t& leaf_A,
323 const BLACSGrid* lg) const;
324#endif //defined(STRUMPACK_USE_MPI)
325
326 protected:
    // NOTE(review): rows_, cols_, U_state_, V_state_, openmp_task_depth_ and
    // active_ have no in-class initializer; presumably they are set by the
    // constructor defined elsewhere -- confirm.
327 std::size_t rows_, cols_;
328
329 // TODO store children array in the sub-class???
330 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> ch_;
331 State U_state_, V_state_;
332 int openmp_task_depth_;
333 bool active_;
334
335 int U_rank_ = 0, U_rows_ = 0, V_rank_ = 0, V_rows_ = 0;
336
337 // Used to redistribute the original 2D block cyclic matrix
338 // according to the HSS tree
339 DenseM_t Asub_;
340
341 HSSFactors<scalar_t> ULV_;
342#if defined(STRUMPACK_USE_MPI)
343 HSSFactorsMPI<scalar_t> ULV_mpi_;
344#endif //defined(STRUMPACK_USE_MPI)
345
    // Rank/row accessors. The values are stored as int and converted to
    // std::size_t on return.
346 virtual std::size_t U_rank() const { return U_rank_; }
347 virtual std::size_t V_rank() const { return V_rank_; }
348 virtual std::size_t U_rows() const { return U_rows_; }
349 virtual std::size_t V_rows() const { return V_rows_; }
350
    // Sequential hooks. Everything from here down to dense_recursive has an
    // empty default body, so a derived class only overrides what it needs.
351 virtual void
352 compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
353 DenseM_t& Sr, DenseM_t& Sc,
354 const elem_t& Aelem, const opts_t& opts,
355 WorkCompress<scalar_t>& w,
356 int dd, int depth) {}
357 virtual void
358 compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
359 DenseM_t& Sr, DenseM_t& Sc,
360 const elem_t& Aelem, const opts_t& opts,
361 WorkCompress<scalar_t>& w,
362 int d, int dd, int depth) {}
363 virtual void
364 compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
365 DenseM_t& Sr, DenseM_t& Sc,
366 const opts_t& opts, WorkCompress<scalar_t>& w,
367 int dd, int lvl, int depth) {}
368 virtual void
369 compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
370 DenseM_t& Sr, DenseM_t& Sc,
371 const opts_t& opts, WorkCompress<scalar_t>& w,
372 int d, int dd, int lvl, int depth) {}
373 virtual void
374 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
375 DenseMatrix<real_t>& scores,
376 const elem_t& Aelem, const opts_t& opts,
377 WorkCompressANN<scalar_t>& w, int depth) {}
378
    // Two overloads: the second additionally takes a vector of dense-block
    // pointers B.
379 virtual void
380 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
381 std::vector<std::vector<std::size_t>>& J,
382 const std::pair<std::size_t,std::size_t>& off,
383 WorkCompress<scalar_t>& w,
384 int& self, int lvl) {}
385
386 virtual void
387 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
388 std::vector<std::vector<std::size_t>>& J,
389 std::vector<DenseM_t*>& B,
390 const std::pair<std::size_t,std::size_t>& off,
391 WorkCompress<scalar_t>& w,
392 int& self, int lvl) {}
393 virtual void extract_D_B(const elem_t& Aelem, const opts_t& opts,
394 WorkCompress<scalar_t>& w, int lvl) {}
395
396 virtual void factor_recursive(WorkFactor<scalar_t>& w,
397 bool isroot, bool partial,
398 int depth) {}
399
    // The sequential apply hooks receive the flop counter as a reference to
    // a std::atomic.
400 virtual void apply_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
401 bool isroot, int depth,
402 std::atomic<long long int>& flops) const {}
403 virtual void apply_bwd(const DenseM_t& b, scalar_t beta,
404 DenseM_t& c, WorkApply<scalar_t>& w,
405 bool isroot, int depth,
406 std::atomic<long long int>& flops) const {}
407 virtual void applyT_fwd(const DenseM_t& b, WorkApply<scalar_t>& w,
408 bool isroot, int depth,
409 std::atomic<long long int>& flops) const {}
410 virtual void applyT_bwd(const DenseM_t& b, scalar_t beta,
411 DenseM_t& c, WorkApply<scalar_t>& w,
412 bool isroot, int depth,
413 std::atomic<long long int>& flops) const {}
414
415 virtual void forward_solve(WorkSolve<scalar_t>& w,
416 const DenseMatrix<scalar_t>& b,
417 bool partial) const {}
418 virtual void backward_solve(WorkSolve<scalar_t>& w,
419 DenseM_t& b) const {}
420 virtual void solve_fwd(const DenseM_t& b,
421 WorkSolve<scalar_t>& w, bool partial,
422 bool isroot, int depth) const {}
423 virtual void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
424 bool isroot, int depth) const {}
425
426 virtual void extract_fwd(WorkExtract<scalar_t>& w,
427 bool odiag, int depth) const {}
428 virtual void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
429 int depth) const {}
430 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
431 WorkExtract<scalar_t>& w, int depth) const {}
432
433 virtual void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop,
434 DenseM_t& Phi, DenseM_t& Vop,
435 const std::pair<std::size_t,std::size_t>& offset,
436 int depth,
437 std::atomic<long long int>& flops) const {}
438 virtual void apply_UtVt_big(const DenseM_t& A, DenseM_t& UtA,
439 DenseM_t& VtA,
440 const std::pair<std::size_t, std::size_t>& offset,
441 int depth,
442 std::atomic<long long int>& flops) const {}
443
444 virtual void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
445 bool isroot, int depth) const {}
446
    // Default read/write only print an error to std::cerr; they neither
    // touch the stream nor signal failure to the caller in any other way.
447 virtual void read(std::ifstream& os) {
448 std::cerr << "ERROR read_HSS_node not implemented" << std::endl;
449 }
450 virtual void write(std::ofstream& os) const {
451 std::cerr << "ERROR write_HSS_node not implemented" << std::endl;
452 }
453
454 friend class HSSMatrix<scalar_t>;
455
456#if defined(STRUMPACK_USE_MPI)
    // Distributed element callback: index sets I and J, two distributed
    // matrices B and A, offsets rlo/clo and an MPI communicator.
457 using delemw_t = typename std::function
458 <void(const std::vector<std::size_t>& I,
459 const std::vector<std::size_t>& J,
460 DistM_t& B, DistM_t& A,
461 std::size_t rlo, std::size_t clo,
462 MPI_Comm comm)>;
463
464
    // MPI hooks. Unlike the sequential ones above, these are only declared
    // here (no inline body) and are defined elsewhere.
465 virtual void
466 compress_recursive_original(DistSamples<scalar_t>& RS,
467 const delemw_t& Aelem,
468 const opts_t& opts,
469 WorkCompressMPI<scalar_t>& w, int dd);
470 virtual void
471 compress_recursive_stable(DistSamples<scalar_t>& RS,
472 const delemw_t& Aelem,
473 const opts_t& opts,
474 WorkCompressMPI<scalar_t>& w, int d, int dd);
475 virtual void
476 compress_level_original(DistSamples<scalar_t>& RS, const opts_t& opts,
477 WorkCompressMPI<scalar_t>& w, int dd, int lvl);
478 virtual void
479 compress_level_stable(DistSamples<scalar_t>& RS, const opts_t& opts,
480 WorkCompressMPI<scalar_t>& w,
481 int d, int dd, int lvl);
482 virtual void
483 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
484 DenseMatrix<real_t>& scores,
485 const delemw_t& Aelem,
486 WorkCompressMPIANN<scalar_t>& w,
487 const opts_t& opts, const BLACSGrid* lg);
488
489 virtual void
490 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
491 std::vector<std::vector<std::size_t>>& J,
492 WorkCompressMPI<scalar_t>& w,
493 int& self, int lvl);
494 virtual void
495 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
496 std::vector<std::vector<std::size_t>>& J,
497 std::vector<DistMW_t>& B,
498 const BLACSGrid* lg,
499 WorkCompressMPI<scalar_t>& w,
500 int& self, int lvl);
501 virtual void extract_D_B(const delemw_t& Aelem, const BLACSGrid* lg,
502 const opts_t& opts,
503 WorkCompressMPI<scalar_t>& w, int lvl);
504
    // NOTE(review): these four overloads take flops by value, whereas the
    // MPI apply_UV_big below takes it by reference; a by-value count cannot
    // be updated for the caller -- confirm this is intended.
505 virtual void apply_fwd(const DistSubLeaf<scalar_t>& B,
506 WorkApplyMPI<scalar_t>& w,
507 bool isroot, long long int flops) const;
508 virtual void apply_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
509 DistSubLeaf<scalar_t>& C,
510 WorkApplyMPI<scalar_t>& w,
511 bool isroot, long long int flops) const;
512 virtual void applyT_fwd(const DistSubLeaf<scalar_t>& B,
513 WorkApplyMPI<scalar_t>& w,
514 bool isroot, long long int flops) const;
515 virtual void applyT_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
516 DistSubLeaf<scalar_t>& C,
517 WorkApplyMPI<scalar_t>& w,
518 bool isroot, long long int flops) const;
519
520 virtual void factor_recursive(WorkFactorMPI<scalar_t>& w,
521 const BLACSGrid* lg, bool isroot,
522 bool partial);
523
524 virtual void solve_fwd(const DistSubLeaf<scalar_t>& b,
525 WorkSolveMPI<scalar_t>& w,
526 bool partial, bool isroot) const;
527 virtual void solve_bwd(DistSubLeaf<scalar_t>& x,
528 WorkSolveMPI<scalar_t>& w, bool isroot) const;
529
    // Two extract_fwd/extract_bwd pairs: one working on a single set of
    // triplets, one on a vector of triplet sets (WorkExtractBlocksMPI).
530 virtual void extract_fwd(WorkExtractMPI<scalar_t>& w,
531 const BLACSGrid* lg, bool odiag) const;
532 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
533 const BLACSGrid* lg,
534 WorkExtractMPI<scalar_t>& w) const;
535 virtual void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
536 const BLACSGrid* lg,
537 std::vector<bool>& odiag) const;
538 virtual void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
539 const BLACSGrid* lg,
540 WorkExtractBlocksMPI<scalar_t>& w) const;
541
542 virtual void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
543 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
544 long long int& flops) const;
545
    // Redistribution of an input distributed matrix A via send buffers
    // (sbuf) and receive pointers (pbuf); presumably the data ends up in
    // Asub_ as described at its declaration -- confirm.
546 virtual void
547 redistribute_to_tree_to_buffers(const DistM_t& A,
548 std::size_t Arlo, std::size_t Aclo,
549 std::vector<std::vector<scalar_t>>& sbuf,
550 int dest);
551 virtual void
552 redistribute_to_tree_from_buffers(const DistM_t& A,
553 std::size_t Arlo, std::size_t Aclo,
554 std::vector<scalar_t*>& pbuf);
555 virtual void delete_redistributed_input();
556
557 friend class HSSMatrixMPI<scalar_t>;
558#endif //defined(STRUMPACK_USE_MPI)
559 };
560
561 } // end namespace HSS
562} // end namespace strumpack
563
564
565#endif // HSS_MATRIX_BASE_HPP
Contains the DenseMatrix and DenseMatrixWrapper classes, simple wrappers around BLAS/LAPACK style den...
Contains the DistributedMatrix and DistributedMatrixWrapper classes, wrappers around ScaLAPACK/PBLAS ...
This file contains the HSSMatrixMPI class definition as well as implementations for a number of its ...
Contains the HSSOptions class as well as general routines for HSS options.
Contains the structured matrix interfaces.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition BLACSGrid.hpp:66
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition DenseMatrix.hpp:1018
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition DenseMatrix.hpp:139
Definition DistributedMatrix.hpp:737
2D block cyclically distributed matrix, as used by ScaLAPACK.
Definition DistributedMatrix.hpp:84
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition HSSMatrixBase.hpp:83
void set_openmp_task_depth(int depth)
Definition HSSMatrixBase.hpp:273
bool active() const
Definition HSSMatrixBase.hpp:239
HSSMatrixBase< scalar_t > & operator=(const HSSMatrixBase< scalar_t > &other)
HSSMatrixBase(HSSMatrixBase &&h)=default
HSSMatrixBase(const HSSMatrixBase< scalar_t > &other)
virtual std::size_t levels() const =0
virtual void draw(std::ostream &of, std::size_t rlo, std::size_t clo) const
Definition HSSMatrixBase.hpp:299
std::size_t rows() const override
Definition HSSMatrixBase.hpp:163
virtual void shift(scalar_t sigma) override=0
std::pair< std::size_t, std::size_t > dims() const
Definition HSSMatrixBase.hpp:155
HSSMatrixBase(std::size_t m, std::size_t n, bool active)
virtual ~HSSMatrixBase()=default
bool is_compressed() const
Definition HSSMatrixBase.hpp:213
virtual void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const =0
bool leaf() const
Definition HSSMatrixBase.hpp:175
const HSSMatrixBase< scalar_t > & child(int c) const
Definition HSSMatrixBase.hpp:188
HSSMatrixBase< scalar_t > & child(int c)
Definition HSSMatrixBase.hpp:201
virtual std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const =0
HSSMatrixBase & operator=(HSSMatrixBase &&h)=default
std::size_t cols() const override
Definition HSSMatrixBase.hpp:169
bool is_untouched() const
Definition HSSMatrixBase.hpp:228
Class containing several options for the HSS code and data-structures.
Definition HSSOptions.hpp:152
Class to represent a structured matrix. This is the abstract base class for several types of structur...
Definition StructuredMatrix.hpp:209
State
Definition HSSExtra.hpp:46
Definition StrumpackOptions.hpp:44