34 #ifndef HSS_MATRIX_BASE_HPP
35 #define HSS_MATRIX_BASE_HPP
43 #include "misc/Triplet.hpp"
45 #include "HSSExtra.hpp"
46 #if defined(STRUMPACK_USE_MPI)
47 #include "dense/DistributedMatrix.hpp"
48 #include "HSSExtraMPI.hpp"
49 #include "HSSMatrixMPI.hpp"
55 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/** Forward declaration of the HSSMatrix class template (defined elsewhere). */
template<typename scalar_t> class HSSMatrix;
57 #if defined(STRUMPACK_USE_MPI)
// Forward declarations of class templates that are only named in this
// header when STRUMPACK_USE_MPI is defined; their definitions live elsewhere.
template<typename scalar_t> class HSSMatrixMPI;
template<typename scalar_t> class DistSubLeaf;
template<typename scalar_t> class DistSamples;
61 #endif //defined(STRUMPACK_USE_MPI)
62 #endif //DOXYGEN_SHOULD_SKIP_THIS
80 using real_t =
typename RealType<scalar_t>::value_type;
83 using elem_t =
typename std::function
84 <void(
const std::vector<std::size_t>& I,
85 const std::vector<std::size_t>& J,
DenseM_t& B)>;
87 #if defined(STRUMPACK_USE_MPI)
90 using delem_t =
typename std::function
91 <void(
const std::vector<std::size_t>& I,
92 const std::vector<std::size_t>& J,
DistM_t& B)>;
93 #endif //defined(STRUMPACK_USE_MPI)
143 virtual std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const = 0;
151 std::pair<std::size_t,std::size_t>
dims()
const {
152 return std::make_pair(_rows, _cols);
159 std::size_t
rows()
const {
return _rows; }
165 std::size_t
cols()
const {
return _cols; }
171 bool leaf()
const {
return _ch.empty(); }
183 assert(c>=0 && c<
int(_ch.size()));
return *(_ch[c]);
196 assert(c>=0 && c<
int(_ch.size()));
return *(_ch[c]);
241 virtual std::size_t
rank()
const = 0;
251 virtual std::size_t
memory()
const = 0;
260 virtual std::size_t
nonzeros()
const = 0;
267 virtual std::size_t
levels()
const = 0;
282 (std::ostream &out=std::cout,
283 std::size_t roff=0, std::size_t coff=0)
const = 0;
296 #ifndef DOXYGEN_SHOULD_SKIP_THIS
297 virtual void delete_trailing_block() {
if (_ch.size()==2) _ch.resize(1); }
298 virtual void reset() {
300 _U_rank = _U_rows = _V_rank = _V_rows = 0;
301 for (
auto& c : _ch) c->reset();
312 virtual void shift(scalar_t sigma) = 0;
321 (std::ostream& of, std::size_t rlo, std::size_t clo)
const {};
323 #if defined(STRUMPACK_USE_MPI)
324 virtual void forward_solve
326 const DistM_t& b,
bool partial)
const;
327 virtual void backward_solve
331 virtual const BLACSGrid* grid()
const {
return nullptr; }
333 return active() ? local_grid :
nullptr;
335 virtual const BLACSGrid* grid_local()
const {
return nullptr; }
336 virtual int Ptotal()
const {
return 1; }
337 virtual int Pactive()
const {
return 1; }
339 virtual void to_block_row
340 (
const DistM_t& A, DenseM_t& sub_A, DistM_t& leaf_A)
const;
341 virtual void allocate_block_row
342 (
int d, DenseM_t& sub_A, DistM_t& leaf_A)
const;
343 virtual void from_block_row
344 (DistM_t& A,
const DenseM_t& sub_A,
const DistM_t& leaf_A,
345 const BLACSGrid* lg)
const;
346 #endif //defined(STRUMPACK_USE_MPI)
349 std::size_t _rows, _cols;
352 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> _ch;
353 State _U_state, _V_state;
354 int _openmp_task_depth;
// Counters exposed through U_rank(), U_rows(), V_rank() and V_rows();
// all start at zero.
int _U_rank = 0;
int _U_rows = 0;
int _V_rank = 0;
int _V_rows = 0;
358 virtual std::size_t U_rank()
const {
return _U_rank; }
359 virtual std::size_t V_rank()
const {
return _V_rank; }
360 virtual std::size_t U_rows()
const {
return _U_rows; };
361 virtual std::size_t V_rows()
const {
return _V_rows; };
367 virtual void compress_recursive_original
368 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
369 const elem_t& Aelem,
const opts_t& opts, WorkCompress<scalar_t>& w,
370 int dd,
int depth) {};
371 virtual void compress_recursive_stable
372 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
373 const elem_t& Aelem,
const opts_t& opts, WorkCompress<scalar_t>& w,
374 int d,
int dd,
int depth) {};
375 virtual void compress_level_original
376 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
377 const opts_t& opts, WorkCompress<scalar_t>& w,
378 int dd,
int lvl,
int depth) {}
379 virtual void compress_level_stable
380 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
381 const opts_t& opts, WorkCompress<scalar_t>& w,
382 int d,
int dd,
int lvl,
int depth) {}
383 virtual void compress_recursive_ann
384 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
385 const elem_t& Aelem,
const opts_t& opts,
386 WorkCompressANN<scalar_t>& w,
int depth) {}
388 virtual void get_extraction_indices
389 (std::vector<std::vector<std::size_t>>& I,
390 std::vector<std::vector<std::size_t>>& J,
391 const std::pair<std::size_t,std::size_t>& off,
392 WorkCompress<scalar_t>& w,
int&
self,
int lvl) {}
394 virtual void get_extraction_indices
395 (std::vector<std::vector<std::size_t>>& I,
396 std::vector<std::vector<std::size_t>>& J,
397 std::vector<DenseM_t*>& B,
398 const std::pair<std::size_t,std::size_t>& off,
399 WorkCompress<scalar_t>& w,
int&
self,
int lvl) {}
400 virtual void extract_D_B
401 (
const elem_t& Aelem,
const opts_t& opts,
402 WorkCompress<scalar_t>& w,
int lvl) {}
404 virtual real_t update_orthogonal_basis
405 (DenseM_t& S,
int d,
int dd,
int depth) {
return real_t(0.); }
407 virtual void factor_recursive
408 (HSSFactors<scalar_t>& ULV, WorkFactor<scalar_t>& w,
409 bool isroot,
bool partial,
int depth)
const {};
411 virtual void apply_fwd
412 (
const DenseM_t& b, WorkApply<scalar_t>& w,
413 bool isroot,
int depth, std::atomic<long long int>& flops)
const {};
414 virtual void apply_bwd
415 (
const DenseM_t& b, scalar_t beta, DenseM_t& c, WorkApply<scalar_t>& w,
416 bool isroot,
int depth, std::atomic<long long int>& flops)
const {};
417 virtual void applyT_fwd
418 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
419 int depth, std::atomic<long long int>& flops)
const {};
420 virtual void applyT_bwd
421 (
const DenseM_t& b, scalar_t beta, DenseM_t& c, WorkApply<scalar_t>& w,
422 bool isroot,
int depth, std::atomic<long long int>& flops)
const {};
424 virtual void forward_solve
425 (
const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
426 const DenseMatrix<scalar_t>& b,
bool partial)
const {};
427 virtual void backward_solve
428 (
const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
429 DenseMatrix<scalar_t>& b)
const {};
430 virtual void solve_fwd
431 (
const HSSFactors<scalar_t>& ULV,
const DenseM_t& b,
432 WorkSolve<scalar_t>& w,
bool partial,
bool isroot,
int depth)
const {};
433 virtual void solve_bwd
434 (
const HSSFactors<scalar_t>& ULV, DenseM_t& x, WorkSolve<scalar_t>& w,
435 bool isroot,
int depth)
const {};
437 virtual void extract_fwd
438 (WorkExtract<scalar_t>& w,
bool odiag,
int depth)
const {};
439 virtual void extract_bwd
440 (DenseMatrix<scalar_t>& B, WorkExtract<scalar_t>& w,
442 virtual void extract_bwd
443 (std::vector<Triplet<scalar_t>>& triplets,
444 WorkExtract<scalar_t>& w,
int depth)
const {};
446 virtual void apply_UV_big
447 (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi, DenseM_t& Vop,
448 const std::pair<std::size_t, std::size_t>& offset,
int depth,
449 std::atomic<long long int>& flops)
const {};
450 virtual void apply_UtVt_big
451 (
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
452 const std::pair<std::size_t, std::size_t>& offset,
453 int depth, std::atomic<long long int>& flops)
const {};
455 virtual void dense_recursive
456 (DenseM_t& A, WorkDense<scalar_t>& w,
bool isroot,
int depth)
const {};
458 friend class HSSMatrix<scalar_t>;
460 #if defined(STRUMPACK_USE_MPI)
461 using delemw_t =
typename std::function
462 <void(
const std::vector<std::size_t>& I,
463 const std::vector<std::size_t>& J,
464 DistM_t& B, DistM_t& A,
465 std::size_t rlo, std::size_t clo,
468 virtual void compress_recursive_original
469 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
470 const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int dd);
471 virtual void compress_recursive_stable
472 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
473 const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int d,
int dd);
474 virtual void compress_level_original
475 (DistSamples<scalar_t>& RS,
const opts_t& opts,
476 WorkCompressMPI<scalar_t>& w,
int dd,
int lvl);
477 virtual void compress_level_stable
478 (DistSamples<scalar_t>& RS,
const opts_t& opts,
479 WorkCompressMPI<scalar_t>& w,
int d,
int dd,
int lvl);
480 virtual void compress_recursive_ann
481 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
482 const delemw_t& Aelem, WorkCompressMPIANN<scalar_t>& w,
483 const opts_t& opts,
const BLACSGrid* lg);
485 virtual void get_extraction_indices
486 (std::vector<std::vector<std::size_t>>& I,
487 std::vector<std::vector<std::size_t>>& J,
488 WorkCompressMPI<scalar_t>& w,
int&
self,
int lvl);
489 virtual void get_extraction_indices
490 (std::vector<std::vector<std::size_t>>& I,
491 std::vector<std::vector<std::size_t>>& J, std::vector<DistMW_t>& B,
492 const BLACSGrid* lg, WorkCompressMPI<scalar_t>& w,
int&
self,
int lvl);
493 virtual void extract_D_B
494 (
const delemw_t& Aelem,
const BLACSGrid* lg,
const opts_t& opts,
495 WorkCompressMPI<scalar_t>& w,
int lvl);
497 virtual void apply_fwd
498 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
499 bool isroot,
long long int flops)
const;
500 virtual void apply_bwd
501 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
502 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
503 bool isroot,
long long int flops)
const;
504 virtual void applyT_fwd
505 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
506 bool isroot,
long long int flops)
const;
507 virtual void applyT_bwd
508 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
509 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
510 bool isroot,
long long int flops)
const;
512 virtual void factor_recursive
513 (HSSFactorsMPI<scalar_t>& ULV, WorkFactorMPI<scalar_t>& w,
514 const BLACSGrid* lg,
bool isroot,
bool partial)
const;
516 virtual void solve_fwd
517 (
const HSSFactorsMPI<scalar_t>& ULV,
const DistSubLeaf<scalar_t>& b,
518 WorkSolveMPI<scalar_t>& w,
bool partial,
bool isroot)
const;
519 virtual void solve_bwd
520 (
const HSSFactorsMPI<scalar_t>& ULV, DistSubLeaf<scalar_t>& x,
521 WorkSolveMPI<scalar_t>& w,
bool isroot)
const;
523 virtual void extract_fwd
524 (WorkExtractMPI<scalar_t>& w,
const BLACSGrid* lg,
bool odiag)
const;
525 virtual void extract_bwd
526 (std::vector<Triplet<scalar_t>>& triplets,
527 const BLACSGrid* lg, WorkExtractMPI<scalar_t>& w)
const;
528 virtual void extract_fwd
529 (WorkExtractBlocksMPI<scalar_t>& w,
const BLACSGrid* lg,
530 std::vector<bool>& odiag)
const;
531 virtual void extract_bwd
532 (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
533 const BLACSGrid* lg, WorkExtractBlocksMPI<scalar_t>& w)
const;
535 virtual void apply_UV_big
536 (DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
537 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
538 long long int& flops)
const;
540 virtual void redistribute_to_tree_to_buffers
541 (
const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
542 std::vector<std::vector<scalar_t>>& sbuf,
int dest);
543 virtual void redistribute_to_tree_from_buffers
544 (
const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
545 std::vector<scalar_t*>& pbuf);
546 virtual void delete_redistributed_input();
548 friend class HSSMatrixMPI<scalar_t>;
549 #endif //defined(STRUMPACK_USE_MPI)
556 #endif // HSS_MATRIX_BASE_HPP