34 #ifndef HSS_MATRIX_BASE_HPP
35 #define HSS_MATRIX_BASE_HPP
45 #include "misc/Triplet.hpp"
47 #include "HSSExtra.hpp"
48 #if defined(STRUMPACK_USE_MPI)
49 #include "dense/DistributedMatrix.hpp"
50 #include "HSSExtraMPI.hpp"
51 #include "HSSMatrixMPI.hpp"
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Forward declarations of the class templates this header mentions by
// name (friend declarations, parameter types) without needing their
// full definitions.
template<typename scalar_t> class HSSMatrix;
#if defined(STRUMPACK_USE_MPI)
// These names are only declared when MPI support is compiled in.
template<typename scalar_t> class HSSMatrixMPI;
template<typename scalar_t> class DistSubLeaf;
template<typename scalar_t> class DistSamples;
#endif // defined(STRUMPACK_USE_MPI)
#endif // DOXYGEN_SHOULD_SKIP_THIS
82 using real_t =
typename RealType<scalar_t>::value_type;
85 using elem_t =
typename std::function
86 <void(
const std::vector<std::size_t>& I,
87 const std::vector<std::size_t>& J,
DenseM_t& B)>;
#if defined(STRUMPACK_USE_MPI)
/**
 * Same shape as elem_t, but the third argument is a DistM_t taken by
 * non-const reference. Only available when MPI support is compiled in.
 * NOTE(review): presumably the callee fills the distributed block B
 * for the index sets I and J -- confirm.
 */
using delem_t = std::function
  <void(const std::vector<std::size_t>& I,
        const std::vector<std::size_t>& J,
        DistM_t& B)>;
#endif // defined(STRUMPACK_USE_MPI)
145 virtual std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const = 0;
153 std::pair<std::size_t,std::size_t>
dims()
const {
154 return std::make_pair(_rows, _cols);
161 std::size_t
rows()
const {
return _rows; }
167 std::size_t
cols()
const {
return _cols; }
173 bool leaf()
const {
return _ch.empty(); }
185 assert(c>=0 && c<
int(_ch.size()));
return *(_ch[c]);
198 assert(c>=0 && c<
int(_ch.size()));
return *(_ch[c]);
243 virtual std::size_t
rank()
const = 0;
253 virtual std::size_t
memory()
const = 0;
262 virtual std::size_t
nonzeros()
const = 0;
269 virtual std::size_t
levels()
const = 0;
284 (std::ostream &out=std::cout,
285 std::size_t roff=0, std::size_t coff=0)
const = 0;
298 #ifndef DOXYGEN_SHOULD_SKIP_THIS
299 virtual void delete_trailing_block() {
if (_ch.size()==2) _ch.resize(1); }
300 virtual void reset() {
302 _U_rank = _U_rows = _V_rank = _V_rows = 0;
303 for (
auto& c : _ch) c->reset();
314 virtual void shift(scalar_t sigma) = 0;
323 (std::ostream& of, std::size_t rlo, std::size_t clo)
const {}
325 #if defined(STRUMPACK_USE_MPI)
326 virtual void forward_solve
328 const DistM_t& b,
bool partial)
const;
329 virtual void backward_solve
333 virtual const BLACSGrid* grid()
const {
return nullptr; }
335 return active() ? local_grid :
nullptr;
337 virtual const BLACSGrid* grid_local()
const {
return nullptr; }
338 virtual int Ptotal()
const {
return 1; }
339 virtual int Pactive()
const {
return 1; }
341 virtual void to_block_row
342 (
const DistM_t& A, DenseM_t& sub_A, DistM_t& leaf_A)
const;
343 virtual void allocate_block_row
344 (
int d, DenseM_t& sub_A, DistM_t& leaf_A)
const;
345 virtual void from_block_row
346 (DistM_t& A,
const DenseM_t& sub_A,
const DistM_t& leaf_A,
347 const BLACSGrid* lg)
const;
348 #endif //defined(STRUMPACK_USE_MPI)
351 std::size_t _rows, _cols;
354 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> _ch;
355 State _U_state, _V_state;
356 int _openmp_task_depth;
359 int _U_rank = 0, _U_rows = 0, _V_rank = 0, _V_rows = 0;
360 virtual std::size_t U_rank()
const {
return _U_rank; }
361 virtual std::size_t V_rank()
const {
return _V_rank; }
362 virtual std::size_t U_rows()
const {
return _U_rows; }
363 virtual std::size_t V_rows()
const {
return _V_rows; }
369 virtual void compress_recursive_original
370 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
371 const elem_t& Aelem,
const opts_t& opts, WorkCompress<scalar_t>& w,
372 int dd,
int depth) {}
373 virtual void compress_recursive_stable
374 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
375 const elem_t& Aelem,
const opts_t& opts, WorkCompress<scalar_t>& w,
376 int d,
int dd,
int depth) {}
377 virtual void compress_level_original
378 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
379 const opts_t& opts, WorkCompress<scalar_t>& w,
380 int dd,
int lvl,
int depth) {}
381 virtual void compress_level_stable
382 (DenseM_t& Rr, DenseM_t& Rc, DenseM_t& Sr, DenseM_t& Sc,
383 const opts_t& opts, WorkCompress<scalar_t>& w,
384 int d,
int dd,
int lvl,
int depth) {}
385 virtual void compress_recursive_ann
386 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
387 const elem_t& Aelem,
const opts_t& opts,
388 WorkCompressANN<scalar_t>& w,
int depth) {}
390 virtual void get_extraction_indices
391 (std::vector<std::vector<std::size_t>>& I,
392 std::vector<std::vector<std::size_t>>& J,
393 const std::pair<std::size_t,std::size_t>& off,
394 WorkCompress<scalar_t>& w,
int&
self,
int lvl) {}
396 virtual void get_extraction_indices
397 (std::vector<std::vector<std::size_t>>& I,
398 std::vector<std::vector<std::size_t>>& J,
399 std::vector<DenseM_t*>& B,
400 const std::pair<std::size_t,std::size_t>& off,
401 WorkCompress<scalar_t>& w,
int&
self,
int lvl) {}
402 virtual void extract_D_B
403 (
const elem_t& Aelem,
const opts_t& opts,
404 WorkCompress<scalar_t>& w,
int lvl) {}
406 virtual void factor_recursive
407 (HSSFactors<scalar_t>& ULV, WorkFactor<scalar_t>& w,
408 bool isroot,
bool partial,
int depth)
const {}
410 virtual void apply_fwd
411 (
const DenseM_t& b, WorkApply<scalar_t>& w,
412 bool isroot,
int depth, std::atomic<long long int>& flops)
const {}
413 virtual void apply_bwd
414 (
const DenseM_t& b, scalar_t beta, DenseM_t& c, WorkApply<scalar_t>& w,
415 bool isroot,
int depth, std::atomic<long long int>& flops)
const {}
416 virtual void applyT_fwd
417 (
const DenseM_t& b, WorkApply<scalar_t>& w,
bool isroot,
418 int depth, std::atomic<long long int>& flops)
const {}
419 virtual void applyT_bwd
420 (
const DenseM_t& b, scalar_t beta, DenseM_t& c, WorkApply<scalar_t>& w,
421 bool isroot,
int depth, std::atomic<long long int>& flops)
const {}
423 virtual void forward_solve
424 (
const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
425 const DenseMatrix<scalar_t>& b,
bool partial)
const {}
426 virtual void backward_solve
427 (
const HSSFactors<scalar_t>& ULV, WorkSolve<scalar_t>& w,
428 DenseMatrix<scalar_t>& b)
const {}
429 virtual void solve_fwd
430 (
const HSSFactors<scalar_t>& ULV,
const DenseM_t& b,
431 WorkSolve<scalar_t>& w,
bool partial,
bool isroot,
int depth)
const {}
432 virtual void solve_bwd
433 (
const HSSFactors<scalar_t>& ULV, DenseM_t& x, WorkSolve<scalar_t>& w,
434 bool isroot,
int depth)
const {}
436 virtual void extract_fwd
437 (WorkExtract<scalar_t>& w,
bool odiag,
int depth)
const {}
438 virtual void extract_bwd
439 (DenseMatrix<scalar_t>& B, WorkExtract<scalar_t>& w,
441 virtual void extract_bwd
442 (std::vector<Triplet<scalar_t>>& triplets,
443 WorkExtract<scalar_t>& w,
int depth)
const {}
445 virtual void apply_UV_big
446 (DenseM_t& Theta, DenseM_t& Uop, DenseM_t& Phi, DenseM_t& Vop,
447 const std::pair<std::size_t, std::size_t>& offset,
int depth,
448 std::atomic<long long int>& flops)
const {}
449 virtual void apply_UtVt_big
450 (
const DenseM_t& A, DenseM_t& UtA, DenseM_t& VtA,
451 const std::pair<std::size_t, std::size_t>& offset,
452 int depth, std::atomic<long long int>& flops)
const {}
454 virtual void dense_recursive
455 (DenseM_t& A, WorkDense<scalar_t>& w,
bool isroot,
int depth)
const {}
457 virtual void read(std::ifstream& os) {
458 std::cerr <<
"ERROR read_HSS_node not implemented" << std::endl;
460 virtual void write(std::ofstream& os)
const {
461 std::cerr <<
"ERROR write_HSS_node not implemented" << std::endl;
464 friend class HSSMatrix<scalar_t>;
466 #if defined(STRUMPACK_USE_MPI)
467 using delemw_t =
typename std::function
468 <void(
const std::vector<std::size_t>& I,
469 const std::vector<std::size_t>& J,
470 DistM_t& B, DistM_t& A,
471 std::size_t rlo, std::size_t clo,
474 virtual void compress_recursive_original
475 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
476 const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int dd);
477 virtual void compress_recursive_stable
478 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
479 const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int d,
int dd);
480 virtual void compress_level_original
481 (DistSamples<scalar_t>& RS,
const opts_t& opts,
482 WorkCompressMPI<scalar_t>& w,
int dd,
int lvl);
483 virtual void compress_level_stable
484 (DistSamples<scalar_t>& RS,
const opts_t& opts,
485 WorkCompressMPI<scalar_t>& w,
int d,
int dd,
int lvl);
486 virtual void compress_recursive_ann
487 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
488 const delemw_t& Aelem, WorkCompressMPIANN<scalar_t>& w,
489 const opts_t& opts,
const BLACSGrid* lg);
491 virtual void get_extraction_indices
492 (std::vector<std::vector<std::size_t>>& I,
493 std::vector<std::vector<std::size_t>>& J,
494 WorkCompressMPI<scalar_t>& w,
int&
self,
int lvl);
495 virtual void get_extraction_indices
496 (std::vector<std::vector<std::size_t>>& I,
497 std::vector<std::vector<std::size_t>>& J, std::vector<DistMW_t>& B,
498 const BLACSGrid* lg, WorkCompressMPI<scalar_t>& w,
int&
self,
int lvl);
499 virtual void extract_D_B
500 (
const delemw_t& Aelem,
const BLACSGrid* lg,
const opts_t& opts,
501 WorkCompressMPI<scalar_t>& w,
int lvl);
503 virtual void apply_fwd
504 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
505 bool isroot,
long long int flops)
const;
506 virtual void apply_bwd
507 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
508 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
509 bool isroot,
long long int flops)
const;
510 virtual void applyT_fwd
511 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
512 bool isroot,
long long int flops)
const;
513 virtual void applyT_bwd
514 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
515 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
516 bool isroot,
long long int flops)
const;
518 virtual void factor_recursive
519 (HSSFactorsMPI<scalar_t>& ULV, WorkFactorMPI<scalar_t>& w,
520 const BLACSGrid* lg,
bool isroot,
bool partial)
const;
522 virtual void solve_fwd
523 (
const HSSFactorsMPI<scalar_t>& ULV,
const DistSubLeaf<scalar_t>& b,
524 WorkSolveMPI<scalar_t>& w,
bool partial,
bool isroot)
const;
525 virtual void solve_bwd
526 (
const HSSFactorsMPI<scalar_t>& ULV, DistSubLeaf<scalar_t>& x,
527 WorkSolveMPI<scalar_t>& w,
bool isroot)
const;
529 virtual void extract_fwd
530 (WorkExtractMPI<scalar_t>& w,
const BLACSGrid* lg,
bool odiag)
const;
531 virtual void extract_bwd
532 (std::vector<Triplet<scalar_t>>& triplets,
533 const BLACSGrid* lg, WorkExtractMPI<scalar_t>& w)
const;
534 virtual void extract_fwd
535 (WorkExtractBlocksMPI<scalar_t>& w,
const BLACSGrid* lg,
536 std::vector<bool>& odiag)
const;
537 virtual void extract_bwd
538 (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
539 const BLACSGrid* lg, WorkExtractBlocksMPI<scalar_t>& w)
const;
541 virtual void apply_UV_big
542 (DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
543 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
544 long long int& flops)
const;
546 virtual void redistribute_to_tree_to_buffers
547 (
const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
548 std::vector<std::vector<scalar_t>>& sbuf,
int dest);
549 virtual void redistribute_to_tree_from_buffers
550 (
const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
551 std::vector<scalar_t*>& pbuf);
552 virtual void delete_redistributed_input();
554 friend class HSSMatrixMPI<scalar_t>;
555 #endif //defined(STRUMPACK_USE_MPI)
562 #endif // HSS_MATRIX_BASE_HPP