84 using real_t =
typename RealType<scalar_t>::value_type;
87 using elem_t =
typename std::function
88 <void(
const std::vector<std::size_t>& I,
89 const std::vector<std::size_t>& J,
DenseM_t& B)>;
91#if defined(STRUMPACK_USE_MPI)
94 using delem_t =
typename std::function
95 <void(
const std::vector<std::size_t>& I,
96 const std::vector<std::size_t>& J,
DistM_t& B)>;
147 virtual std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const = 0;
// Returns the dimensions of this block as the pair (rows_, cols_).
155 std::pair<std::size_t,std::size_t>
dims()
const {
156 return std::make_pair(rows_, cols_);
163 std::size_t
rows()
const override {
return rows_; }
169 std::size_t
cols()
const override {
return cols_; }
175 bool leaf()
const {
return ch_.empty(); }
177 virtual std::size_t factor_nonzeros()
const;
// Access child c of this node. The bound check is an assert, so it is
// only active in builds compiled without NDEBUG.
189 assert(c>=0 && c<
int(ch_.size()));
return *(ch_[c]);
// Same bounds-checked child access in a second function (presumably the
// const/non-const counterpart of the one above).
202 assert(c>=0 && c<
int(ch_.size()));
return *(ch_[c]);
// NOTE(review): a default argument (coff=0) on a pure virtual function is
// chosen by the static type at the call site; overrides should not declare
// a different default.
262 std::size_t coff=0)
const = 0;
275#ifndef DOXYGEN_SHOULD_SKIP_THIS
276 virtual void delete_trailing_block() {
if (ch_.size()==2) ch_.resize(1); }
// Resets this node. Among its effects, it zeroes the U/V rank and row
// counters and recursively calls reset() on every child.
277 virtual void reset() {
279 U_rank_ = U_rows_ = V_rank_ = V_rows_ = 0;
280 for (
auto& c : ch_) c->reset();
291 virtual void shift(scalar_t sigma)
override = 0;
// Drawing hook writing to the stream `of`; the base-class implementation is empty.
299 virtual void draw(std::ostream& of,
301 std::size_t clo)
const {}
303#if defined(STRUMPACK_USE_MPI)
304 virtual void forward_solve(WorkSolveMPI<scalar_t>& w,
305 const DistM_t& b,
bool partial)
const;
// Backward-solve step for the distributed (WorkSolveMPI) case.
306 virtual void backward_solve(WorkSolveMPI<scalar_t>& w,
// BLACS process grid of this node; the base class has none and returns nullptr.
309 virtual const BLACSGrid* grid()
const {
return nullptr; }
// Yields local_grid when this node is active(), otherwise nullptr.
311 return active() ? local_grid :
nullptr;
313 virtual const BLACSGrid* grid_local()
const {
return nullptr; }
314 virtual int Ptotal()
const {
return 1; }
315 virtual int Pactive()
const {
return 1; }
317 virtual void to_block_row(
const DistM_t& A, DenseM_t& sub_A,
318 DistM_t& leaf_A)
const;
319 virtual void allocate_block_row(
int d, DenseM_t& sub_A,
320 DistM_t& leaf_A)
const;
321 virtual void from_block_row(DistM_t& A,
const DenseM_t& sub_A,
322 const DistM_t& leaf_A,
323 const BLACSGrid* lg)
const;
327 std::size_t rows_, cols_;
330 std::vector<std::unique_ptr<HSSMatrixBase<scalar_t>>> ch_;
331 State U_state_, V_state_;
332 int openmp_task_depth_;
335 int U_rank_ = 0, U_rows_ = 0, V_rank_ = 0, V_rows_ = 0;
341 HSSFactors<scalar_t> ULV_;
342#if defined(STRUMPACK_USE_MPI)
343 HSSFactorsMPI<scalar_t> ULV_mpi_;
346 virtual std::size_t U_rank()
const {
return U_rank_; }
347 virtual std::size_t V_rank()
const {
return V_rank_; }
348 virtual std::size_t U_rows()
const {
return U_rows_; }
349 virtual std::size_t V_rows()
const {
return V_rows_; }
// Sequential compression kernels. Every base-class version below has an
// empty body, so it is presumably overridden by the concrete node types.
// "original" and "stable" name two variants; the *_level_* versions take an
// extra lvl argument (presumably level-by-level traversal).
352 compress_recursive_original(DenseM_t& Rr, DenseM_t& Rc,
353 DenseM_t& Sr, DenseM_t& Sc,
354 const elem_t& Aelem,
const opts_t& opts,
355 WorkCompress<scalar_t>& w,
356 int dd,
int depth) {}
358 compress_recursive_stable(DenseM_t& Rr, DenseM_t& Rc,
359 DenseM_t& Sr, DenseM_t& Sc,
360 const elem_t& Aelem,
const opts_t& opts,
361 WorkCompress<scalar_t>& w,
362 int d,
int dd,
int depth) {}
364 compress_level_original(DenseM_t& Rr, DenseM_t& Rc,
365 DenseM_t& Sr, DenseM_t& Sc,
366 const opts_t& opts, WorkCompress<scalar_t>& w,
367 int dd,
int lvl,
int depth) {}
369 compress_level_stable(DenseM_t& Rr, DenseM_t& Rc,
370 DenseM_t& Sr, DenseM_t& Sc,
371 const opts_t& opts, WorkCompress<scalar_t>& w,
372 int d,
int dd,
int lvl,
int depth) {}
// ANN-based compression: takes the ann (uint32) and scores (real_t)
// matrices in place of the Rr/Rc/Sr/Sc matrices used above.
374 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
375 DenseMatrix<real_t>& scores,
376 const elem_t& Aelem,
const opts_t& opts,
377 WorkCompressANN<scalar_t>& w,
int depth) {}
// Collect the row (I) and column (J) index sets to extract; the second
// overload also collects destination blocks B. Both are empty here.
380 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
381 std::vector<std::vector<std::size_t>>& J,
382 const std::pair<std::size_t,std::size_t>& off,
383 WorkCompress<scalar_t>& w,
384 int& self,
int lvl) {}
387 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
388 std::vector<std::vector<std::size_t>>& J,
389 std::vector<DenseM_t*>& B,
390 const std::pair<std::size_t,std::size_t>& off,
391 WorkCompress<scalar_t>& w,
392 int& self,
int lvl) {}
393 virtual void extract_D_B(
const elem_t& Aelem,
const opts_t& opts,
394 WorkCompress<scalar_t>& w,
int lvl) {}
// Recursive factorization step for the sequential (WorkFactor) case.
396 virtual void factor_recursive(WorkFactor<scalar_t>& w,
397 bool isroot,
bool partial,
400 virtual void apply_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
401 bool isroot,
int depth,
402 std::atomic<long long int>& flops)
const {}
403 virtual void apply_bwd(
const DenseM_t& b, scalar_t beta,
404 DenseM_t& c, WorkApply<scalar_t>& w,
405 bool isroot,
int depth,
406 std::atomic<long long int>& flops)
const {}
407 virtual void applyT_fwd(
const DenseM_t& b, WorkApply<scalar_t>& w,
408 bool isroot,
int depth,
409 std::atomic<long long int>& flops)
const {}
410 virtual void applyT_bwd(
const DenseM_t& b, scalar_t beta,
411 DenseM_t& c, WorkApply<scalar_t>& w,
412 bool isroot,
int depth,
413 std::atomic<long long int>& flops)
const {}
415 virtual void forward_solve(WorkSolve<scalar_t>& w,
416 const DenseMatrix<scalar_t>& b,
417 bool partial)
const {}
418 virtual void backward_solve(WorkSolve<scalar_t>& w,
419 DenseM_t& b)
const {}
420 virtual void solve_fwd(
const DenseM_t& b,
421 WorkSolve<scalar_t>& w,
bool partial,
422 bool isroot,
int depth)
const {}
423 virtual void solve_bwd(DenseM_t& x, WorkSolve<scalar_t>& w,
424 bool isroot,
int depth)
const {}
426 virtual void extract_fwd(WorkExtract<scalar_t>& w,
427 bool odiag,
int depth)
const {}
// Backward pass of element extraction into the dense block B (sequential case).
428 virtual void extract_bwd(DenseM_t& B, WorkExtract<scalar_t>& w,
430 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
431 WorkExtract<scalar_t>& w,
int depth)
const {}
// Hook applying the big U/V bases at the given offset; Theta and Phi are
// non-const (presumably outputs). The base-class implementation is empty.
433 virtual void apply_UV_big(DenseM_t& Theta, DenseM_t& Uop,
434 DenseM_t& Phi, DenseM_t& Vop,
435 const std::pair<std::size_t,std::size_t>& offset,
437 std::atomic<long long int>& flops)
const {}
// Transposed counterpart that reads A and writes UtA (non-const);
// also empty in the base class.
438 virtual void apply_UtVt_big(
const DenseM_t& A, DenseM_t& UtA,
440 const std::pair<std::size_t, std::size_t>& offset,
442 std::atomic<long long int>& flops)
const {}
444 virtual void dense_recursive(DenseM_t& A, WorkDense<scalar_t>& w,
445 bool isroot,
int depth)
const {}
// Deserialization hook; the base version only reports an error on std::cerr.
// NOTE(review): the input-stream parameter is named `os`; `is` would read better.
447 virtual void read(std::ifstream& os) {
448 std::cerr <<
"ERROR read_HSS_node not implemented" << std::endl;
// Serialization hook; the base version only reports an error on std::cerr.
450 virtual void write(std::ofstream& os)
const {
451 std::cerr <<
"ERROR write_HSS_node not implemented" << std::endl;
// Grants HSSMatrix<scalar_t> access to this class's non-public members.
454 friend class HSSMatrix<scalar_t>;
// Distributed (MPI) declarations. delemw_t is an element-extraction callback
// whose parameters begin with I, J, B, A, rlo and clo.
456#if defined(STRUMPACK_USE_MPI)
457 using delemw_t =
typename std::function
458 <void(
const std::vector<std::size_t>& I,
459 const std::vector<std::size_t>& J,
460 DistM_t& B, DistM_t& A,
461 std::size_t rlo, std::size_t clo,
// Distributed counterparts of the sequential compression kernels, operating
// on DistSamples RS and a WorkCompressMPI workspace. These are declarations
// only, defined elsewhere.
466 compress_recursive_original(DistSamples<scalar_t>& RS,
467 const delemw_t& Aelem,
469 WorkCompressMPI<scalar_t>& w,
int dd);
471 compress_recursive_stable(DistSamples<scalar_t>& RS,
472 const delemw_t& Aelem,
474 WorkCompressMPI<scalar_t>& w,
int d,
int dd);
476 compress_level_original(DistSamples<scalar_t>& RS,
const opts_t& opts,
477 WorkCompressMPI<scalar_t>& w,
int dd,
int lvl);
479 compress_level_stable(DistSamples<scalar_t>& RS,
const opts_t& opts,
480 WorkCompressMPI<scalar_t>& w,
481 int d,
int dd,
int lvl);
// ANN-based distributed compression, additionally taking a BLACS grid lg.
483 compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
484 DenseMatrix<real_t>& scores,
485 const delemw_t& Aelem,
486 WorkCompressMPIANN<scalar_t>& w,
487 const opts_t& opts,
const BLACSGrid* lg);
// Distributed extraction-index collection; the second overload also
// collects DistMW_t destination blocks B.
490 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
491 std::vector<std::vector<std::size_t>>& J,
492 WorkCompressMPI<scalar_t>& w,
495 get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
496 std::vector<std::vector<std::size_t>>& J,
497 std::vector<DistMW_t>& B,
499 WorkCompressMPI<scalar_t>& w,
// Distributed D/B block extraction through the delemw_t callback.
501 virtual void extract_D_B(
const delemw_t& Aelem,
const BLACSGrid* lg,
503 WorkCompressMPI<scalar_t>& w,
int lvl);
505 virtual void apply_fwd(
const DistSubLeaf<scalar_t>& B,
506 WorkApplyMPI<scalar_t>& w,
507 bool isroot,
long long int flops)
const;
508 virtual void apply_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
509 DistSubLeaf<scalar_t>&
C,
510 WorkApplyMPI<scalar_t>& w,
511 bool isroot,
long long int flops)
const;
512 virtual void applyT_fwd(
const DistSubLeaf<scalar_t>& B,
513 WorkApplyMPI<scalar_t>& w,
514 bool isroot,
long long int flops)
const;
515 virtual void applyT_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
516 DistSubLeaf<scalar_t>&
C,
517 WorkApplyMPI<scalar_t>& w,
518 bool isroot,
long long int flops)
const;
// Recursive factorization step for the distributed (WorkFactorMPI) case on BLACS grid lg.
520 virtual void factor_recursive(WorkFactorMPI<scalar_t>& w,
521 const BLACSGrid* lg,
bool isroot,
524 virtual void solve_fwd(
const DistSubLeaf<scalar_t>& b,
525 WorkSolveMPI<scalar_t>& w,
526 bool partial,
bool isroot)
const;
527 virtual void solve_bwd(DistSubLeaf<scalar_t>& x,
528 WorkSolveMPI<scalar_t>& w,
bool isroot)
const;
530 virtual void extract_fwd(WorkExtractMPI<scalar_t>& w,
531 const BLACSGrid* lg,
bool odiag)
const;
// Distributed backward extraction pass collecting results into a triplet list.
532 virtual void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
534 WorkExtractMPI<scalar_t>& w)
const;
// Multi-block variants: odiag holds one flag per block, and triplets holds
// one triplet list per block.
535 virtual void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
537 std::vector<bool>& odiag)
const;
538 virtual void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
540 WorkExtractBlocksMPI<scalar_t>& w)
const;
542 virtual void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
543 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
544 long long int& flops)
const;
// Redistribution of the input matrix A (offsets Arlo/Aclo) to the tree:
// the first step packs into per-destination send buffers sbuf, and the
// second reads from the received buffers pbuf.
547 redistribute_to_tree_to_buffers(
const DistM_t& A,
548 std::size_t Arlo, std::size_t Aclo,
549 std::vector<std::vector<scalar_t>>& sbuf,
552 redistribute_to_tree_from_buffers(
const DistM_t& A,
553 std::size_t Arlo, std::size_t Aclo,
554 std::vector<scalar_t*>& pbuf);
// Releases the redistributed input (declaration only).
555 virtual void delete_redistributed_input();
// Grants HSSMatrixMPI<scalar_t> access to this class's non-public members.
557 friend class HSSMatrixMPI<scalar_t>;