36#ifndef HSS_MATRIX_MPI_HPP
37#define HSS_MATRIX_MPI_HPP
43#include "HSSExtraMPI.hpp"
44#include "DistSamples.hpp"
45#include "DistElemMult.hpp"
46#include "HSSBasisIDMPI.hpp"
/**
 * Alias for RealType<scalar_t>::value_type.
 * NOTE(review): presumably the real-valued type associated with
 * scalar_t (e.g. for norms/tolerances) -- confirm against RealType.
 */
69 using real_t =
typename RealType<scalar_t>::value_type;
/**
 * Callback type: a std::function returning void that receives two
 * index vectors (I, J) by const reference and a DistM_t (B) by
 * mutable reference.
 * NOTE(review): presumably I/J are row/column indices and the
 * callback fills B with the corresponding matrix entries -- confirm
 * against the callers.
 */
74 using delem_t =
typename std::function
75 <void(
const std::vector<std::size_t>& I,
76 const std::vector<std::size_t>& J,
DistM_t& B)>;
/**
 * Callback type: a std::function returning void that receives two
 * vectors of index vectors (I, J) by const reference and a vector of
 * DistMW_t (B) by mutable reference.
 * NOTE(review): presumably the multi-block counterpart of delem_t,
 * with one (I[k], J[k], B[k]) triple per block -- confirm.
 */
77 using delem_blocks_t =
typename std::function
78 <void(
const std::vector<std::vector<std::size_t>>& I,
79 const std::vector<std::vector<std::size_t>>& J,
80 std::vector<DistMW_t>& B)>;
/**
 * Callback type: a std::function returning void that receives two
 * index vectors (I, J) by const reference and a DenseM_t (B) by
 * mutable reference. Same shape as delem_t, but with a DenseM_t
 * output instead of a DistM_t.
 */
81 using elem_t =
typename std::function
82 <void(
const std::vector<std::size_t>& I,
83 const std::vector<std::size_t>& J,
DenseM_t& B)>;
84 using dmult_t =
typename std::function
96 const dmult_t& Amult,
const delem_t& Aelem,
99 const dmult_t& Amult,
const delem_blocks_t& Aelem,
102 const dmult_t& Amult,
const delem_t& Aelem,
112 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
115 return this->ch_[c].get();
/**
 * Return the BLACS grid pointer stored in blacs_grid_.
 * The returned pointer is non-owning from the caller's side.
 */
119 const BLACSGrid* grid()
const override {
return blacs_grid_; }
/**
 * Overload of grid() that accepts a grid pointer. The argument is
 * not used here: blacs_grid_ is returned regardless of what is
 * passed in.
 * NOTE(review): presumably the parameter exists only to match the
 * base-class virtual being overridden -- confirm.
 */
120 const BLACSGrid* grid(
const BLACSGrid* grid)
const override {
return blacs_grid_; }
/**
 * Return the BLACS grid pointer stored in blacs_grid_local_.
 */
121 const BLACSGrid* grid_local()
const override {
return blacs_grid_local_; }
/**
 * Return the MPIComm wrapper of this matrix's grid, i.e.
 * grid()->Comm(). grid() is dereferenced without a null check, so it
 * must return a valid pointer when this is called.
 */
122 const MPIComm& Comm()
const {
return grid()->
Comm(); }
/**
 * Return the MPI_Comm obtained from Comm().comm(), i.e. the
 * communicator handle held by the MPIComm wrapper of grid().
 */
123 MPI_Comm comm()
const {
return Comm().
comm(); }
/**
 * Return grid()->P().
 * NOTE(review): presumably the total number of processes in the
 * BLACS grid -- confirm against BLACSGrid::P().
 * grid() is dereferenced without a null check.
 */
124 int Ptotal()
const override {
return grid()->
P(); }
/**
 * Return grid()->npactives().
 * NOTE(review): presumably the number of processes that are active
 * in the BLACS grid -- confirm against BLACSGrid::npactives().
 * grid() is dereferenced without a null check.
 */
125 int Pactive()
const override {
return grid()->
npactives(); }
128 void compress(
const DistM_t& A,
130 void compress(
const dmult_t& Amult,
131 const delem_t& Aelem,
133 void compress(
const dmult_t& Amult,
134 const delem_blocks_t& Aelem,
136 void compress(
const kernel::Kernel<real_t>& K,
const opts_t& opts);
139 void partial_factor();
141 void forward_solve(WorkSolveMPI<scalar_t>& w,
const DistM_t& b,
142 bool partial)
const override;
143 void backward_solve(WorkSolveMPI<scalar_t>& w,
151 scalar_t get(std::size_t i, std::size_t j)
const;
152 DistM_t extract(
const std::vector<std::size_t>& I,
153 const std::vector<std::size_t>& J,
156 extract(
const std::vector<std::vector<std::size_t>>& I,
157 const std::vector<std::vector<std::size_t>>& J,
159 void extract_add(
const std::vector<std::size_t>& I,
160 const std::vector<std::size_t>& J,
DistM_t& B)
const;
161 void extract_add(
const std::vector<std::vector<std::size_t>>& I,
162 const std::vector<std::vector<std::size_t>>& J,
163 std::vector<DistM_t>& B)
const;
167 void Schur_product_direct(
const DistM_t& Theta,
176 std::size_t max_rank()
const;
177 std::size_t total_memory()
const;
178 std::size_t total_nonzeros()
const;
179 std::size_t total_factor_nonzeros()
const;
180 std::size_t max_levels()
const;
181 std::size_t
rank()
const override;
184 std::size_t factor_nonzeros()
const override;
189 std::size_t coff=0)
const override;
193 void shift(scalar_t sigma)
override;
196 void to_block_row(
const DistM_t& A,
198 DistM_t& leaf_A)
const override;
199 void allocate_block_row(
int d, DenseM_t& sub_A,
200 DistM_t& leaf_A)
const override;
201 void from_block_row(DistM_t& A,
202 const DenseM_t& sub_A,
203 const DistM_t& leaf_A,
206 void delete_trailing_block()
override;
207 void reset()
override;
210 using delemw_t =
typename std::function
211 <void(
const std::vector<std::size_t>& I,
212 const std::vector<std::size_t>& J,
213 DistM_t& B, DistM_t& A,
214 std::size_t rlo, std::size_t clo,
219 std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
220 std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
224 HSSBasisIDMPI<scalar_t> U_, V_;
225 DistM_t D_, B01_, B10_;
229 DistM_t A_, A01_, A10_;
231 HSSMatrixMPI(std::size_t m, std::size_t n,
const opts_t& opts,
233 std::size_t roff, std::size_t coff);
236 std::size_t roff, std::size_t coff);
237 void setup_hierarchy(
const opts_t& opts,
238 std::size_t roff, std::size_t coff);
240 std::size_t roff, std::size_t coff);
241 void setup_local_context();
242 void setup_ranges(std::size_t roff, std::size_t coff);
244 void compress_original_nosync(
const dmult_t& Amult,
245 const delemw_t& Aelem,
247 void compress_original_sync(
const dmult_t& Amult,
248 const delemw_t& Aelem,
250 void compress_original_sync(
const dmult_t& Amult,
251 const delem_blocks_t& Aelem,
253 void compress_stable_nosync(
const dmult_t& Amult,
254 const delemw_t& Aelem,
256 void compress_stable_sync(
const dmult_t& Amult,
257 const delemw_t& Aelem,
259 void compress_stable_sync(
const dmult_t& Amult,
260 const delem_blocks_t& Aelem,
262 void compress_hard_restart_nosync(
const dmult_t& Amult,
263 const delemw_t& Aelem,
265 void compress_hard_restart_sync(
const dmult_t& Amult,
266 const delemw_t& Aelem,
268 void compress_hard_restart_sync(
const dmult_t& Amult,
269 const delem_blocks_t& Aelem,
274 const delemw_t& Aelem,
275 WorkCompressMPIANN<scalar_t>& w,
280 WorkCompressMPIANN<scalar_t>& w,
281 const delemw_t& Aelem,
283 bool compute_U_V_bases_ann(DistM_t& S,
const opts_t& opts,
284 WorkCompressMPIANN<scalar_t>& w);
285 void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
287 void compress_recursive_original(DistSamples<scalar_t>& RS,
288 const delemw_t& Aelem,
290 WorkCompressMPI<scalar_t>& w,
292 void compress_recursive_stable(DistSamples<scalar_t>& RS,
293 const delemw_t& Aelem,
295 WorkCompressMPI<scalar_t>& w,
296 int d,
int dd)
override;
297 void compute_local_samples(
const DistSamples<scalar_t>& RS,
298 WorkCompressMPI<scalar_t>& w,
int dd);
299 bool compute_U_V_bases(
int d,
const opts_t& opts,
300 WorkCompressMPI<scalar_t>& w);
301 void compute_U_basis_stable(
const opts_t& opts,
302 WorkCompressMPI<scalar_t>& w,
304 void compute_V_basis_stable(
const opts_t& opts,
305 WorkCompressMPI<scalar_t>& w,
307 bool update_orthogonal_basis(
const opts_t& opts,
308 scalar_t& r_max_0,
const DistM_t& S,
309 DistM_t& Q,
int d,
int dd,
310 bool untouched,
int L);
311 void reduce_local_samples(
const DistSamples<scalar_t>& RS,
312 WorkCompressMPI<scalar_t>& w,
313 int dd,
bool was_compressed);
314 void communicate_child_data(WorkCompressMPI<scalar_t>& w);
315 void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
316 void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
317 void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
319 void compress_level_original(DistSamples<scalar_t>& RS,
321 WorkCompressMPI<scalar_t>& w,
322 int dd,
int lvl)
override;
323 void compress_level_stable(DistSamples<scalar_t>& RS,
325 WorkCompressMPI<scalar_t>& w,
326 int d,
int dd,
int lvl)
override;
327 void extract_level(
const delemw_t& Aelem,
const opts_t& opts,
328 WorkCompressMPI<scalar_t>& w,
int lvl);
329 void extract_level(
const delem_blocks_t& Aelem,
const opts_t& opts,
330 WorkCompressMPI<scalar_t>& w,
int lvl);
331 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
332 std::vector<std::vector<std::size_t>>& J,
333 WorkCompressMPI<scalar_t>& w,
334 int& self,
int lvl)
override;
335 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
336 std::vector<std::vector<std::size_t>>& J,
337 std::vector<DistMW_t>& B,
339 WorkCompressMPI<scalar_t>& w,
340 int& self,
int lvl)
override;
341 void allgather_extraction_indices(std::vector<std::vector<std::size_t>>& lI,
342 std::vector<std::vector<std::size_t>>& lJ,
343 std::vector<std::vector<std::size_t>>& I,
344 std::vector<std::vector<std::size_t>>& J,
345 int& before,
int self,
int& after);
346 void extract_D_B(
const delemw_t& Aelem,
348 WorkCompressMPI<scalar_t>& w,
int lvl)
override;
350 void factor_recursive(WorkFactorMPI<scalar_t>& w,
352 bool isroot,
bool partial)
override;
354 void solve_fwd(
const DistSubLeaf<scalar_t>& b,
355 WorkSolveMPI<scalar_t>& w,
356 bool partial,
bool isroot)
const override;
357 void solve_bwd(DistSubLeaf<scalar_t>& x,
358 WorkSolveMPI<scalar_t>& w,
bool isroot)
const override;
360 void apply_fwd(
const DistSubLeaf<scalar_t>& B,
361 WorkApplyMPI<scalar_t>& w,
362 bool isroot,
long long int flops)
const override;
363 void apply_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
364 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
365 bool isroot,
long long int flops)
const override;
366 void applyT_fwd(
const DistSubLeaf<scalar_t>& B,
367 WorkApplyMPI<scalar_t>& w,
368 bool isroot,
long long int flops)
const override;
369 void applyT_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
370 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
371 bool isroot,
long long int flops)
const override;
373 void extract_fwd(WorkExtractMPI<scalar_t>& w,
const BLACSGrid* lg,
374 bool odiag)
const override;
375 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
377 WorkExtractMPI<scalar_t>& w)
const override;
378 void triplets_to_DistM(std::vector<Triplet<scalar_t>>& triplets,
380 void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
382 std::vector<bool>& odiag)
const override;
383 void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
385 WorkExtractBlocksMPI<scalar_t>& w)
const override;
386 void triplets_to_DistM(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
387 std::vector<DistM_t>& B)
const;
389 void redistribute_to_tree_to_buffers(
const DistM_t& A,
390 std::size_t Arlo, std::size_t Aclo,
391 std::vector<std::vector<scalar_t>>& sbuf,
392 int dest=0)
override;
393 void redistribute_to_tree_from_buffers(
const DistM_t& A,
394 std::size_t rlo, std::size_t clo,
395 std::vector<scalar_t*>& pbuf)
397 void delete_redistributed_input()
override;
399 void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
400 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
401 long long int& flops)
const override;
403 static int Pl(std::size_t n, std::size_t nl, std::size_t nr,
int P) {
405 (1, std::min(
int(std::round(
float(P) * nl / n)), P-1));
407 static int Pr(std::size_t n, std::size_t nl, std::size_t nr,
int P) {
408 return std::max(1, P - Pl(n, nl, nr, P));
411 return Pl(this->
rows(), child(0)->
rows(),
412 child(1)->
rows(), Ptotal());
415 return Pr(this->
rows(), child(0)->
rows(),
416 child(1)->
rows(), Ptotal());
419 template<
typename T>
friend
420 void apply_HSS(
Trans ta,
const HSSMatrixMPI<T>& a,
421 const DistributedMatrix<T>& b, T beta,
422 DistributedMatrix<T>& c);
423 friend class DistSamples<scalar_t>;
This file contains the HSSMatrix class definition as well as implementations for a number of its mem...
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Contains some simple C++ MPI wrapper utilities.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition: BLACSGrid.hpp:66
const MPIComm & Comm() const
Definition: BLACSGrid.hpp:196
int npactives() const
Definition: BLACSGrid.hpp:257
int P() const
Definition: BLACSGrid.hpp:251
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Definition: DistributedMatrix.hpp:733
2D block-cyclically distributed matrix, as used by ScaLAPACK.
Definition: DistributedMatrix.hpp:84
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
std::size_t rows() const override
Definition: HSSMatrixBase.hpp:163
HSSMatrixBase(std::size_t m, std::size_t n, bool active)
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
void mult(Trans op, const DistM_t &x, DistM_t &y) const override
void shift(scalar_t sigma) override
std::size_t nonzeros() const override
std::size_t levels() const override
void solve(DistM_t &b) const override
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
std::size_t rank() const override
std::size_t memory() const override
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Definition: HSSExtraMPI.hpp:134
Wrapper class around an MPI_Comm object.
Definition: MPIWrapper.hpp:194
MPI_Comm comm() const
Definition: MPIWrapper.hpp:261
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition: ClusterTree.hpp:67
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51