Loading...
Searching...
No Matches
HSSMatrixMPI.hpp
Go to the documentation of this file.
1/*
2 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3 * Regents of the University of California, through Lawrence Berkeley
4 * National Laboratory (subject to receipt of any required approvals
5 * from the U.S. Dept. of Energy). All rights reserved.
6 *
7 * If you have questions about your rights to use or distribute this
8 * software, please contact Berkeley Lab's Technology Transfer
9 * Department at TTD@lbl.gov.
10 *
11 * NOTICE. This software is owned by the U.S. Department of Energy. As
12 * such, the U.S. Government has been granted for itself and others
13 * acting on its behalf a paid-up, nonexclusive, irrevocable,
14 * worldwide license in the Software to reproduce, prepare derivative
15 * works, and perform publicly and display publicly. Beginning five
16 * (5) years after the date permission to assert copyright is obtained
17 * from the U.S. Department of Energy, and subject to any subsequent
18 * five (5) year renewals, the U.S. Government is granted for itself
19 * and others acting on its behalf a paid-up, nonexclusive,
20 * irrevocable, worldwide license in the Software to reproduce,
21 * prepare derivative works, distribute copies to the public, perform
22 * publicly and display publicly, and to permit others to do so.
23 *
24 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25 * (Lawrence Berkeley National Lab, Computational Research
26 * Division).
27 */
36#ifndef HSS_MATRIX_MPI_HPP
37#define HSS_MATRIX_MPI_HPP
38
39#include <cassert>
40
41#include "HSSMatrix.hpp"
42#include "misc/MPIWrapper.hpp"
43#include "HSSExtraMPI.hpp"
44#include "DistSamples.hpp"
45#include "DistElemMult.hpp"
46#include "HSSBasisIDMPI.hpp"
47#include "kernel/Kernel.hpp"
48
49namespace strumpack {
50 namespace HSS {
51
67 template<typename scalar_t> class HSSMatrixMPI
68 : public HSSMatrixBase<scalar_t> {
69 using real_t = typename RealType<scalar_t>::value_type;
74 using delem_t = typename std::function
75 <void(const std::vector<std::size_t>& I,
76 const std::vector<std::size_t>& J, DistM_t& B)>;
77 using delem_blocks_t = typename std::function
78 <void(const std::vector<std::vector<std::size_t>>& I,
79 const std::vector<std::vector<std::size_t>>& J,
80 std::vector<DistMW_t>& B)>;
81 using elem_t = typename std::function
82 <void(const std::vector<std::size_t>& I,
83 const std::vector<std::size_t>& J, DenseM_t& B)>;
84 using dmult_t = typename std::function
85 <void(DistM_t& R, DistM_t& Sr, DistM_t& Sc)>;
87
88 public:
// Default constructor: forwards (0, 0, true) to the HSSMatrixBase
// constructor, i.e. m = 0, n = 0 and active = true.
89 HSSMatrixMPI() : HSSMatrixBase<scalar_t>(0, 0, true) {}
90 HSSMatrixMPI(const DistM_t& A, const opts_t& opts);
92 const DistM_t& A, const opts_t& opts);
94 const BLACSGrid* g, const opts_t& opts);
95 HSSMatrixMPI(std::size_t m, std::size_t n, const BLACSGrid* Agrid,
96 const dmult_t& Amult, const delem_t& Aelem,
97 const opts_t& opts);
98 HSSMatrixMPI(std::size_t m, std::size_t n, const BLACSGrid* Agrid,
99 const dmult_t& Amult, const delem_blocks_t& Aelem,
100 const opts_t& opts);
101 HSSMatrixMPI(const structured::ClusterTree& t, const BLACSGrid* Agrid,
102 const dmult_t& Amult, const delem_t& Aelem,
103 const opts_t& opts);
105 const opts_t& opts);
107 HSSMatrixMPI(HSSMatrixMPI<scalar_t>&& other) = default;
108 virtual ~HSSMatrixMPI() {}
109
110 HSSMatrixMPI<scalar_t>& operator=(const HSSMatrixMPI<scalar_t>& other);
111 HSSMatrixMPI<scalar_t>& operator=(HSSMatrixMPI<scalar_t>&& other) = default;
112 std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
113
// Read-only access to child number c: returns the raw pointer obtained
// from this->ch_[c].get(). No range check on c is visible here.
114 const HSSMatrixBase<scalar_t>* child(int c) const {
115 return this->ch_[c].get();
116 }
// Mutable overload of child(): same lookup, returns a non-const pointer.
117 HSSMatrixBase<scalar_t>* child(int c) { return this->ch_[c].get(); }
118
// Returns the stored blacs_grid_ pointer.
119 const BLACSGrid* grid() const override { return blacs_grid_; }
// Overload taking a grid argument: the argument is ignored and the
// stored blacs_grid_ pointer is returned.
120 const BLACSGrid* grid(const BLACSGrid* grid) const override { return blacs_grid_; }
// Returns the stored blacs_grid_local_ pointer.
121 const BLACSGrid* grid_local() const override { return blacs_grid_local_; }
// MPIComm wrapper of the grid returned by grid(); dereferences grid().
122 const MPIComm& Comm() const { return grid()->Comm(); }
// Underlying MPI_Comm handle of Comm().
123 MPI_Comm comm() const { return Comm().comm(); }
// Forwards to grid()->P().
124 int Ptotal() const override { return grid()->P(); }
// Forwards to grid()->npactives().
125 int Pactive() const override { return grid()->npactives(); }
126
127
128 void compress(const DistM_t& A,
129 const opts_t& opts);
130 void compress(const dmult_t& Amult,
131 const delem_t& Aelem,
132 const opts_t& opts);
133 void compress(const dmult_t& Amult,
134 const delem_blocks_t& Aelem,
135 const opts_t& opts);
136 void compress(const kernel::Kernel<real_t>& K, const opts_t& opts);
137
138 void factor() override;
139 void partial_factor();
140 void solve(DistM_t& b) const override;
141
142 void forward_solve(WorkSolveMPI<scalar_t>& w, const DistM_t& b,
143 bool partial) const override;
144 void backward_solve(WorkSolveMPI<scalar_t>& w,
145 DistM_t& x) const override;
146
147 DistM_t apply(const DistM_t& b) const;
148 DistM_t applyC(const DistM_t& b) const;
149
150 void mult(Trans op, const DistM_t& x, DistM_t& y) const override;
151
152 scalar_t get(std::size_t i, std::size_t j) const;
153 DistM_t extract(const std::vector<std::size_t>& I,
154 const std::vector<std::size_t>& J,
155 const BLACSGrid* Bgrid) const;
156 std::vector<DistM_t>
157 extract(const std::vector<std::vector<std::size_t>>& I,
158 const std::vector<std::vector<std::size_t>>& J,
159 const BLACSGrid* Bgrid) const;
160 void extract_add(const std::vector<std::size_t>& I,
161 const std::vector<std::size_t>& J, DistM_t& B) const;
162 void extract_add(const std::vector<std::vector<std::size_t>>& I,
163 const std::vector<std::vector<std::size_t>>& J,
164 std::vector<DistM_t>& B) const;
165
166 void Schur_update(DistM_t& Theta, DistM_t& Vhat,
167 DistM_t& DUB01, DistM_t& Phi) const;
168 void Schur_product_direct(const DistM_t& Theta,
169 const DistM_t& Vhat,
170 const DistM_t& DUB01,
171 const DistM_t& Phi,
172 const DistM_t&_ThetaVhatC,
173 const DistM_t& VhatCPhiC,
174 const DistM_t& R,
175 DistM_t& Sr, DistM_t& Sc) const;
176
177 std::size_t max_rank() const; // collective on comm()
178 std::size_t total_memory() const; // collective on comm()
179 std::size_t total_nonzeros() const; // collective on comm()
180 std::size_t total_factor_nonzeros() const; // collective on comm()
181 std::size_t max_levels() const; // collective on comm()
182 std::size_t rank() const override;
183 std::size_t memory() const override;
184 std::size_t nonzeros() const override;
185 std::size_t factor_nonzeros() const override;
186 std::size_t levels() const override;
187
188 void print_info(std::ostream &out=std::cout,
189 std::size_t roff=0,
190 std::size_t coff=0) const override;
191
192 DistM_t dense() const;
193
194 void shift(scalar_t sigma) override;
195
// Read-only accessor for the ranges_ member.
196 const TreeLocalRanges& tree_ranges() const { return ranges_; }
197 void to_block_row(const DistM_t& A,
198 DenseM_t& sub_A,
199 DistM_t& leaf_A) const override;
200 void allocate_block_row(int d, DenseM_t& sub_A,
201 DistM_t& leaf_A) const override;
202 void from_block_row(DistM_t& A,
203 const DenseM_t& sub_A,
204 const DistM_t& leaf_A,
205 const BLACSGrid* lgrid) const override;
206
207 void delete_trailing_block() override;
208 void reset() override;
209
// Returns a const reference to this->ULV_mpi_.
// NOTE(review): this accessor is not const-qualified, so it cannot be
// called through a const HSSMatrixMPI -- confirm whether that is intended.
210 const HSSFactorsMPI<scalar_t>& ULV() { return this->ULV_mpi_; }
211
212 private:
213 using delemw_t = typename std::function
214 <void(const std::vector<std::size_t>& I,
215 const std::vector<std::size_t>& J,
216 DistM_t& B, DistM_t& A,
217 std::size_t rlo, std::size_t clo,
218 MPI_Comm comm)>;
219
220 const BLACSGrid* blacs_grid_;
221 const BLACSGrid* blacs_grid_local_;
222 std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
223 std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
224
225 TreeLocalRanges ranges_;
226
227 HSSBasisIDMPI<scalar_t> U_, V_;
228 DistM_t D_, B01_, B10_;
229
230 // Used to redistribute the original 2D block cyclic matrix
231 // according to the HSS tree
232 DistM_t A_, A01_, A10_;
233
234 HSSMatrixMPI(std::size_t m, std::size_t n, const opts_t& opts,
235 const MPIComm& c, int P,
236 std::size_t roff, std::size_t coff);
237 HSSMatrixMPI(const structured::ClusterTree& t, const opts_t& opts,
238 const MPIComm& c, int P,
239 std::size_t roff, std::size_t coff);
240 void setup_hierarchy(const opts_t& opts,
241 std::size_t roff, std::size_t coff);
242 void setup_hierarchy(const structured::ClusterTree& t, const opts_t& opts,
243 std::size_t roff, std::size_t coff);
244 void setup_local_context();
245 void setup_ranges(std::size_t roff, std::size_t coff);
246
247 void compress_original_nosync(const dmult_t& Amult,
248 const delemw_t& Aelem,
249 const opts_t& opts);
250 void compress_original_sync(const dmult_t& Amult,
251 const delemw_t& Aelem,
252 const opts_t& opts);
253 void compress_original_sync(const dmult_t& Amult,
254 const delem_blocks_t& Aelem,
255 const opts_t& opts);
256 void compress_stable_nosync(const dmult_t& Amult,
257 const delemw_t& Aelem,
258 const opts_t& opts);
259 void compress_stable_sync(const dmult_t& Amult,
260 const delemw_t& Aelem,
261 const opts_t& opts);
262 void compress_stable_sync(const dmult_t& Amult,
263 const delem_blocks_t& Aelem,
264 const opts_t& opts);
265 void compress_hard_restart_nosync(const dmult_t& Amult,
266 const delemw_t& Aelem,
267 const opts_t& opts);
268 void compress_hard_restart_sync(const dmult_t& Amult,
269 const delemw_t& Aelem,
270 const opts_t& opts);
271 void compress_hard_restart_sync(const dmult_t& Amult,
272 const delem_blocks_t& Aelem,
273 const opts_t& opts);
274
275 void compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
276 DenseMatrix<real_t>& scores,
277 const delemw_t& Aelem,
278 WorkCompressMPIANN<scalar_t>& w,
279 const opts_t& opts,
280 const BLACSGrid* lg) override;
281 void compute_local_samples_ann(DenseMatrix<std::uint32_t>& ann,
282 DenseMatrix<real_t>& scores,
283 WorkCompressMPIANN<scalar_t>& w,
284 const delemw_t& Aelem,
285 const opts_t& opts);
286 bool compute_U_V_bases_ann(DistM_t& S, const opts_t& opts,
287 WorkCompressMPIANN<scalar_t>& w);
288 void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
289
290 void compress_recursive_original(DistSamples<scalar_t>& RS,
291 const delemw_t& Aelem,
292 const opts_t& opts,
293 WorkCompressMPI<scalar_t>& w,
294 int dd) override;
295 void compress_recursive_stable(DistSamples<scalar_t>& RS,
296 const delemw_t& Aelem,
297 const opts_t& opts,
298 WorkCompressMPI<scalar_t>& w,
299 int d, int dd) override;
300 void compute_local_samples(const DistSamples<scalar_t>& RS,
301 WorkCompressMPI<scalar_t>& w, int dd);
302 bool compute_U_V_bases(int d, const opts_t& opts,
303 WorkCompressMPI<scalar_t>& w);
304 void compute_U_basis_stable(const opts_t& opts,
305 WorkCompressMPI<scalar_t>& w,
306 int d, int dd);
307 void compute_V_basis_stable(const opts_t& opts,
308 WorkCompressMPI<scalar_t>& w,
309 int d, int dd);
310 bool update_orthogonal_basis(const opts_t& opts,
311 scalar_t& r_max_0, const DistM_t& S,
312 DistM_t& Q, int d, int dd,
313 bool untouched, int L);
314 void reduce_local_samples(const DistSamples<scalar_t>& RS,
315 WorkCompressMPI<scalar_t>& w,
316 int dd, bool was_compressed);
317 void communicate_child_data(WorkCompressMPI<scalar_t>& w);
318 void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
319 void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
320 void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
321
322 void compress_level_original(DistSamples<scalar_t>& RS,
323 const opts_t& opts,
324 WorkCompressMPI<scalar_t>& w,
325 int dd, int lvl) override;
326 void compress_level_stable(DistSamples<scalar_t>& RS,
327 const opts_t& opts,
328 WorkCompressMPI<scalar_t>& w,
329 int d, int dd, int lvl) override;
330 void extract_level(const delemw_t& Aelem, const opts_t& opts,
331 WorkCompressMPI<scalar_t>& w, int lvl);
332 void extract_level(const delem_blocks_t& Aelem, const opts_t& opts,
333 WorkCompressMPI<scalar_t>& w, int lvl);
334
335 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
336 std::vector<std::vector<std::size_t>>& J,
337 WorkCompressMPI<scalar_t>& w,
338 int& self, int lvl) override;
339 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
340 std::vector<std::vector<std::size_t>>& J,
341 std::vector<DistMW_t>& B,
342 const BLACSGrid* lg,
343 WorkCompressMPI<scalar_t>& w,
344 int& self, int lvl) override;
345 void allgather_extraction_indices(std::vector<std::vector<std::size_t>>& lI,
346 std::vector<std::vector<std::size_t>>& lJ,
347 std::vector<std::vector<std::size_t>>& I,
348 std::vector<std::vector<std::size_t>>& J,
349 int& before, int self, int& after);
350
351 void extract_D_B(const delemw_t& Aelem,
352 const BLACSGrid* lg, const opts_t& opts,
353 WorkCompressMPI<scalar_t>& w, int lvl) override;
354
355 void factor_recursive(WorkFactorMPI<scalar_t>& w,
356 const BLACSGrid* lg,
357 bool isroot, bool partial) override;
358
359 void solve_fwd(const DistSubLeaf<scalar_t>& b,
360 WorkSolveMPI<scalar_t>& w,
361 bool partial, bool isroot) const override;
362 void solve_bwd(DistSubLeaf<scalar_t>& x,
363 WorkSolveMPI<scalar_t>& w, bool isroot) const override;
364
365 void apply_fwd(const DistSubLeaf<scalar_t>& B,
366 WorkApplyMPI<scalar_t>& w,
367 bool isroot, long long int flops) const override;
368 void apply_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
369 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
370 bool isroot, long long int flops) const override;
371 void applyT_fwd(const DistSubLeaf<scalar_t>& B,
372 WorkApplyMPI<scalar_t>& w,
373 bool isroot, long long int flops) const override;
374 void applyT_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
375 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
376 bool isroot, long long int flops) const override;
377
378 void extract_fwd(WorkExtractMPI<scalar_t>& w, const BLACSGrid* lg,
379 bool odiag) const override;
380 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
381 const BLACSGrid* lg,
382 WorkExtractMPI<scalar_t>& w) const override;
383 void triplets_to_DistM(std::vector<Triplet<scalar_t>>& triplets,
384 DistM_t& B) const;
385 void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
386 const BLACSGrid* lg,
387 std::vector<bool>& odiag) const override;
388 void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
389 const BLACSGrid* lg,
390 WorkExtractBlocksMPI<scalar_t>& w) const override;
391 void triplets_to_DistM(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
392 std::vector<DistM_t>& B) const;
393
394 void redistribute_to_tree_to_buffers(const DistM_t& A,
395 std::size_t Arlo, std::size_t Aclo,
396 std::vector<std::vector<scalar_t>>& sbuf,
397 int dest=0) override;
398 void redistribute_to_tree_from_buffers(const DistM_t& A,
399 std::size_t rlo, std::size_t clo,
400 std::vector<scalar_t*>& pbuf)
401 override;
402 void delete_redistributed_input() override;
403
404 void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
405 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
406 long long int& flops) const override;
407
/**
 * Number of processes, out of P, assigned to the left part when a
 * dimension of size n is split into a left part of size nl.
 *
 * The share is round(P * nl / n), clamped to the range [1, P-1]; the
 * lower bound wins, so the result is never smaller than 1 (also for
 * P == 1).
 *
 * \param n   total size being split
 * \param nl  size of the left part
 * \param nr  size of the right part (unused, kept for a signature
 *            parallel to Pr)
 * \param P   total number of processes
 * \return number of processes for the left part, at least 1
 */
static int Pl(std::size_t n, std::size_t nl, std::size_t nr, int P) {
  (void)nr;
  // With n == 0 the ratio nl / n is NaN or infinite, and converting
  // such a value to int is undefined behaviour. There is nothing to
  // balance in that case, so fall back to the minimum of one process.
  if (n == 0) return 1;
  const int share = int(std::round(float(P) * nl / n));
  return std::max(1, std::min(share, P-1));
}

/**
 * Number of processes, out of P, assigned to the right part: whatever
 * Pl(n, nl, nr, P) leaves over, but never fewer than 1.
 *
 * \param n   total size being split
 * \param nl  size of the left part
 * \param nr  size of the right part (only forwarded to Pl)
 * \param P   total number of processes
 * \return number of processes for the right part, at least 1
 */
static int Pr(std::size_t n, std::size_t nl, std::size_t nr, int P) {
  return std::max(1, P - Pl(n, nl, nr, P));
}
// Process count for the left child of this node: calls the static Pl
// with this node's row count, the row counts of child(0) and child(1),
// and Ptotal(). Dereferences both children.
415 int Pl() const {
416 return Pl(this->rows(), child(0)->rows(),
417 child(1)->rows(), Ptotal());
418 }
// Process count for the right child of this node: same arguments as
// Pl() above, passed to the static Pr.
419 int Pr() const {
420 return Pr(this->rows(), child(0)->rows(),
421 child(1)->rows(), Ptotal());
422 }
423
424 template<typename T> friend
425 void apply_HSS(Trans ta, const HSSMatrixMPI<T>& a,
426 const DistributedMatrix<T>& b, T beta,
427 DistributedMatrix<T>& c);
428 friend class DistSamples<scalar_t>;
429
430 using HSSMatrixBase<scalar_t>::child;
431
432 // suppress warnings
433 using structured::StructuredMatrix<scalar_t>::mult;
434 using structured::StructuredMatrix<scalar_t>::solve;
435 using HSSMatrixBase<scalar_t>::forward_solve;
436 using HSSMatrixBase<scalar_t>::backward_solve;
437 using HSSMatrixBase<scalar_t>::compress_recursive_ann;
438 using HSSMatrixBase<scalar_t>::compress_recursive_original;
439 using HSSMatrixBase<scalar_t>::compress_recursive_stable;
440 using HSSMatrixBase<scalar_t>::compress_level_original;
441 using HSSMatrixBase<scalar_t>::compress_level_stable;
442 using HSSMatrixBase<scalar_t>::get_extraction_indices;
443 using HSSMatrixBase<scalar_t>::extract_D_B;
444 using HSSMatrixBase<scalar_t>::factor_recursive;
445 using HSSMatrixBase<scalar_t>::solve_fwd;
446 using HSSMatrixBase<scalar_t>::solve_bwd;
447 using HSSMatrixBase<scalar_t>::apply_fwd;
448 using HSSMatrixBase<scalar_t>::apply_bwd;
449 using HSSMatrixBase<scalar_t>::applyT_fwd;
450 using HSSMatrixBase<scalar_t>::applyT_bwd;
451 using HSSMatrixBase<scalar_t>::extract_fwd;
452 using HSSMatrixBase<scalar_t>::extract_bwd;
453 using HSSMatrixBase<scalar_t>::apply_UV_big;
454 };
455
456 } // end namespace HSS
457} // end namespace strumpack
458
459#endif // HSS_MATRIX_MPI_HPP
This file contains the HSSMatrix class definition as well as implementations for a number of its mem...
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Contains some simple C++ MPI wrapper utilities.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition BLACSGrid.hpp:66
const MPIComm & Comm() const
Definition BLACSGrid.hpp:196
int npactives() const
Definition BLACSGrid.hpp:257
int P() const
Definition BLACSGrid.hpp:251
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition DenseMatrix.hpp:1018
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition DenseMatrix.hpp:139
Definition DistributedMatrix.hpp:737
2D block cyclically distributed matrix, as used by ScaLAPACK.
Definition DistributedMatrix.hpp:84
Contains data related to ULV factorization of a distributed HSS matrix.
Definition HSSExtraMPI.hpp:184
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition HSSMatrixBase.hpp:83
std::size_t rows() const override
Definition HSSMatrixBase.hpp:163
HSSMatrixBase(std::size_t m, std::size_t n, bool active)
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition HSSMatrixMPI.hpp:68
void mult(Trans op, const DistM_t &x, DistM_t &y) const override
void shift(scalar_t sigma) override
std::size_t nonzeros() const override
std::size_t levels() const override
void solve(DistM_t &b) const override
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
std::size_t rank() const override
std::size_t memory() const override
Class containing several options for the HSS code and data-structures.
Definition HSSOptions.hpp:152
Definition HSSExtraMPI.hpp:134
Wrapper class around an MPI_Comm object.
Definition MPIWrapper.hpp:173
MPI_Comm comm() const
Definition MPIWrapper.hpp:240
Representation of a kernel matrix.
Definition Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition ClusterTree.hpp:67
Definition StrumpackOptions.hpp:44
Trans
Definition DenseMatrix.hpp:51