HSSMatrixMPI.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  */
36 #ifndef HSS_MATRIX_MPI_HPP
37 #define HSS_MATRIX_MPI_HPP
38 
39 #include <cassert>
40 
41 #include "HSSMatrix.hpp"
42 #include "misc/MPIWrapper.hpp"
43 #include "HSSExtraMPI.hpp"
44 #include "DistSamples.hpp"
45 #include "DistElemMult.hpp"
46 #include "HSSBasisIDMPI.hpp"
47 #include "kernel/Kernel.hpp"
48 
49 namespace strumpack {
50  namespace HSS {
51 
67  template<typename scalar_t> class HSSMatrixMPI
68  : public HSSMatrixBase<scalar_t> {
69  using real_t = typename RealType<scalar_t>::value_type;
74  using delem_t = typename std::function
75  <void(const std::vector<std::size_t>& I,
76  const std::vector<std::size_t>& J, DistM_t& B)>;
77  using delem_blocks_t = typename std::function
78  <void(const std::vector<std::vector<std::size_t>>& I,
79  const std::vector<std::vector<std::size_t>>& J,
80  std::vector<DistMW_t>& B)>;
81  using elem_t = typename std::function
82  <void(const std::vector<std::size_t>& I,
83  const std::vector<std::size_t>& J, DenseM_t& B)>;
84  using dmult_t = typename std::function
85  <void(DistM_t& R, DistM_t& Sr, DistM_t& Sc)>;
87 
88  public:
89  HSSMatrixMPI() : HSSMatrixBase<scalar_t>(0, 0, true) {}
90  HSSMatrixMPI(const DistM_t& A, const opts_t& opts);
92  const DistM_t& A, const opts_t& opts);
94  const BLACSGrid* g, const opts_t& opts);
95  HSSMatrixMPI(std::size_t m, std::size_t n, const BLACSGrid* Agrid,
96  const dmult_t& Amult, const delem_t& Aelem,
97  const opts_t& opts);
98  HSSMatrixMPI(std::size_t m, std::size_t n, const BLACSGrid* Agrid,
99  const dmult_t& Amult, const delem_blocks_t& Aelem,
100  const opts_t& opts);
101  HSSMatrixMPI(const structured::ClusterTree& t, const BLACSGrid* Agrid,
102  const dmult_t& Amult, const delem_t& Aelem,
103  const opts_t& opts);
105  const opts_t& opts);
106  HSSMatrixMPI(const HSSMatrixMPI<scalar_t>& other);
107  HSSMatrixMPI(HSSMatrixMPI<scalar_t>&& other) = default;
108  virtual ~HSSMatrixMPI() {}
109 
110  HSSMatrixMPI<scalar_t>& operator=(const HSSMatrixMPI<scalar_t>& other);
111  HSSMatrixMPI<scalar_t>& operator=(HSSMatrixMPI<scalar_t>&& other) = default;
112  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
113 
114  const HSSMatrixBase<scalar_t>* child(int c) const {
115  return this->ch_[c].get();
116  }
117  HSSMatrixBase<scalar_t>* child(int c) { return this->ch_[c].get(); }
118 
119  const BLACSGrid* grid() const override { return blacs_grid_; }
120  const BLACSGrid* grid(const BLACSGrid* grid) const override { return blacs_grid_; }
121  const BLACSGrid* grid_local() const override { return blacs_grid_local_; }
122  const MPIComm& Comm() const { return grid()->Comm(); }
123  MPI_Comm comm() const { return Comm().comm(); }
124  int Ptotal() const override { return grid()->P(); }
125  int Pactive() const override { return grid()->npactives(); }
126 
127 
128  void compress(const DistM_t& A,
129  const opts_t& opts);
130  void compress(const dmult_t& Amult,
131  const delem_t& Aelem,
132  const opts_t& opts);
133  void compress(const dmult_t& Amult,
134  const delem_blocks_t& Aelem,
135  const opts_t& opts);
136  void compress(const kernel::Kernel<real_t>& K, const opts_t& opts);
137 
138  void factor() override;
139  void partial_factor();
140  void solve(DistM_t& b) const override;
141  void forward_solve(WorkSolveMPI<scalar_t>& w, const DistM_t& b,
142  bool partial) const override;
143  void backward_solve(WorkSolveMPI<scalar_t>& w,
144  DistM_t& x) const override;
145 
146  DistM_t apply(const DistM_t& b) const;
147  DistM_t applyC(const DistM_t& b) const;
148 
149  void mult(Trans op, const DistM_t& x, DistM_t& y) const override;
150 
151  scalar_t get(std::size_t i, std::size_t j) const;
152  DistM_t extract(const std::vector<std::size_t>& I,
153  const std::vector<std::size_t>& J,
154  const BLACSGrid* Bgrid) const;
155  std::vector<DistM_t>
156  extract(const std::vector<std::vector<std::size_t>>& I,
157  const std::vector<std::vector<std::size_t>>& J,
158  const BLACSGrid* Bgrid) const;
159  void extract_add(const std::vector<std::size_t>& I,
160  const std::vector<std::size_t>& J, DistM_t& B) const;
161  void extract_add(const std::vector<std::vector<std::size_t>>& I,
162  const std::vector<std::vector<std::size_t>>& J,
163  std::vector<DistM_t>& B) const;
164 
165  void Schur_update(DistM_t& Theta, DistM_t& Vhat,
166  DistM_t& DUB01, DistM_t& Phi) const;
167  void Schur_product_direct(const DistM_t& Theta,
168  const DistM_t& Vhat,
169  const DistM_t& DUB01,
170  const DistM_t& Phi,
171  const DistM_t&_ThetaVhatC,
172  const DistM_t& VhatCPhiC,
173  const DistM_t& R,
174  DistM_t& Sr, DistM_t& Sc) const;
175 
176  std::size_t max_rank() const; // collective on comm()
177  std::size_t total_memory() const; // collective on comm()
178  std::size_t total_nonzeros() const; // collective on comm()
179  std::size_t total_factor_nonzeros() const; // collective on comm()
180  std::size_t max_levels() const; // collective on comm()
181  std::size_t rank() const override;
182  std::size_t memory() const override;
183  std::size_t nonzeros() const override;
184  std::size_t factor_nonzeros() const override;
185  std::size_t levels() const override;
186 
187  void print_info(std::ostream &out=std::cout,
188  std::size_t roff=0,
189  std::size_t coff=0) const override;
190 
191  DistM_t dense() const;
192 
193  void shift(scalar_t sigma) override;
194 
195  const TreeLocalRanges& tree_ranges() const { return ranges_; }
196  void to_block_row(const DistM_t& A,
197  DenseM_t& sub_A,
198  DistM_t& leaf_A) const override;
199  void allocate_block_row(int d, DenseM_t& sub_A,
200  DistM_t& leaf_A) const override;
201  void from_block_row(DistM_t& A,
202  const DenseM_t& sub_A,
203  const DistM_t& leaf_A,
204  const BLACSGrid* lgrid) const override;
205 
206  void delete_trailing_block() override;
207  void reset() override;
208 
209  private:
210  using delemw_t = typename std::function
211  <void(const std::vector<std::size_t>& I,
212  const std::vector<std::size_t>& J,
213  DistM_t& B, DistM_t& A,
214  std::size_t rlo, std::size_t clo,
215  MPI_Comm comm)>;
216 
217  const BLACSGrid* blacs_grid_;
218  const BLACSGrid* blacs_grid_local_;
219  std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
220  std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
221 
222  TreeLocalRanges ranges_;
223 
224  HSSBasisIDMPI<scalar_t> U_, V_;
225  DistM_t D_, B01_, B10_;
226 
227  // Used to redistribute the original 2D block cyclic matrix
228  // according to the HSS tree
229  DistM_t A_, A01_, A10_;
230 
231  HSSMatrixMPI(std::size_t m, std::size_t n, const opts_t& opts,
232  const MPIComm& c, int P,
233  std::size_t roff, std::size_t coff);
234  HSSMatrixMPI(const structured::ClusterTree& t, const opts_t& opts,
235  const MPIComm& c, int P,
236  std::size_t roff, std::size_t coff);
237  void setup_hierarchy(const opts_t& opts,
238  std::size_t roff, std::size_t coff);
239  void setup_hierarchy(const structured::ClusterTree& t, const opts_t& opts,
240  std::size_t roff, std::size_t coff);
241  void setup_local_context();
242  void setup_ranges(std::size_t roff, std::size_t coff);
243 
244  void compress_original_nosync(const dmult_t& Amult,
245  const delemw_t& Aelem,
246  const opts_t& opts);
247  void compress_original_sync(const dmult_t& Amult,
248  const delemw_t& Aelem,
249  const opts_t& opts);
250  void compress_original_sync(const dmult_t& Amult,
251  const delem_blocks_t& Aelem,
252  const opts_t& opts);
253  void compress_stable_nosync(const dmult_t& Amult,
254  const delemw_t& Aelem,
255  const opts_t& opts);
256  void compress_stable_sync(const dmult_t& Amult,
257  const delemw_t& Aelem,
258  const opts_t& opts);
259  void compress_stable_sync(const dmult_t& Amult,
260  const delem_blocks_t& Aelem,
261  const opts_t& opts);
262  void compress_hard_restart_nosync(const dmult_t& Amult,
263  const delemw_t& Aelem,
264  const opts_t& opts);
265  void compress_hard_restart_sync(const dmult_t& Amult,
266  const delemw_t& Aelem,
267  const opts_t& opts);
268  void compress_hard_restart_sync(const dmult_t& Amult,
269  const delem_blocks_t& Aelem,
270  const opts_t& opts);
271 
272  void compress_recursive_ann(DenseMatrix<std::uint32_t>& ann,
273  DenseMatrix<real_t>& scores,
274  const delemw_t& Aelem,
275  WorkCompressMPIANN<scalar_t>& w,
276  const opts_t& opts,
277  const BLACSGrid* lg) override;
278  void compute_local_samples_ann(DenseMatrix<std::uint32_t>& ann,
279  DenseMatrix<real_t>& scores,
280  WorkCompressMPIANN<scalar_t>& w,
281  const delemw_t& Aelem,
282  const opts_t& opts);
283  bool compute_U_V_bases_ann(DistM_t& S, const opts_t& opts,
284  WorkCompressMPIANN<scalar_t>& w);
285  void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
286 
287  void compress_recursive_original(DistSamples<scalar_t>& RS,
288  const delemw_t& Aelem,
289  const opts_t& opts,
290  WorkCompressMPI<scalar_t>& w,
291  int dd) override;
292  void compress_recursive_stable(DistSamples<scalar_t>& RS,
293  const delemw_t& Aelem,
294  const opts_t& opts,
295  WorkCompressMPI<scalar_t>& w,
296  int d, int dd) override;
297  void compute_local_samples(const DistSamples<scalar_t>& RS,
298  WorkCompressMPI<scalar_t>& w, int dd);
299  bool compute_U_V_bases(int d, const opts_t& opts,
300  WorkCompressMPI<scalar_t>& w);
301  void compute_U_basis_stable(const opts_t& opts,
302  WorkCompressMPI<scalar_t>& w,
303  int d, int dd);
304  void compute_V_basis_stable(const opts_t& opts,
305  WorkCompressMPI<scalar_t>& w,
306  int d, int dd);
307  bool update_orthogonal_basis(const opts_t& opts,
308  scalar_t& r_max_0, const DistM_t& S,
309  DistM_t& Q, int d, int dd,
310  bool untouched, int L);
311  void reduce_local_samples(const DistSamples<scalar_t>& RS,
312  WorkCompressMPI<scalar_t>& w,
313  int dd, bool was_compressed);
314  void communicate_child_data(WorkCompressMPI<scalar_t>& w);
315  void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
316  void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
317  void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
318 
319  void compress_level_original(DistSamples<scalar_t>& RS,
320  const opts_t& opts,
321  WorkCompressMPI<scalar_t>& w,
322  int dd, int lvl) override;
323  void compress_level_stable(DistSamples<scalar_t>& RS,
324  const opts_t& opts,
325  WorkCompressMPI<scalar_t>& w,
326  int d, int dd, int lvl) override;
327  void extract_level(const delemw_t& Aelem, const opts_t& opts,
328  WorkCompressMPI<scalar_t>& w, int lvl);
329  void extract_level(const delem_blocks_t& Aelem, const opts_t& opts,
330  WorkCompressMPI<scalar_t>& w, int lvl);
331  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
332  std::vector<std::vector<std::size_t>>& J,
333  WorkCompressMPI<scalar_t>& w,
334  int& self, int lvl) override;
335  void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
336  std::vector<std::vector<std::size_t>>& J,
337  std::vector<DistMW_t>& B,
338  const BLACSGrid* lg,
339  WorkCompressMPI<scalar_t>& w,
340  int& self, int lvl) override;
341  void allgather_extraction_indices(std::vector<std::vector<std::size_t>>& lI,
342  std::vector<std::vector<std::size_t>>& lJ,
343  std::vector<std::vector<std::size_t>>& I,
344  std::vector<std::vector<std::size_t>>& J,
345  int& before, int self, int& after);
346  void extract_D_B(const delemw_t& Aelem,
347  const BLACSGrid* lg, const opts_t& opts,
348  WorkCompressMPI<scalar_t>& w, int lvl) override;
349 
350  void factor_recursive(WorkFactorMPI<scalar_t>& w,
351  const BLACSGrid* lg,
352  bool isroot, bool partial) override;
353 
354  void solve_fwd(const DistSubLeaf<scalar_t>& b,
355  WorkSolveMPI<scalar_t>& w,
356  bool partial, bool isroot) const override;
357  void solve_bwd(DistSubLeaf<scalar_t>& x,
358  WorkSolveMPI<scalar_t>& w, bool isroot) const override;
359 
360  void apply_fwd(const DistSubLeaf<scalar_t>& B,
361  WorkApplyMPI<scalar_t>& w,
362  bool isroot, long long int flops) const override;
363  void apply_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
364  DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
365  bool isroot, long long int flops) const override;
366  void applyT_fwd(const DistSubLeaf<scalar_t>& B,
367  WorkApplyMPI<scalar_t>& w,
368  bool isroot, long long int flops) const override;
369  void applyT_bwd(const DistSubLeaf<scalar_t>& B, scalar_t beta,
370  DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
371  bool isroot, long long int flops) const override;
372 
373  void extract_fwd(WorkExtractMPI<scalar_t>& w, const BLACSGrid* lg,
374  bool odiag) const override;
375  void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
376  const BLACSGrid* lg,
377  WorkExtractMPI<scalar_t>& w) const override;
378  void triplets_to_DistM(std::vector<Triplet<scalar_t>>& triplets,
379  DistM_t& B) const;
380  void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
381  const BLACSGrid* lg,
382  std::vector<bool>& odiag) const override;
383  void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
384  const BLACSGrid* lg,
385  WorkExtractBlocksMPI<scalar_t>& w) const override;
386  void triplets_to_DistM(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
387  std::vector<DistM_t>& B) const;
388 
389  void redistribute_to_tree_to_buffers(const DistM_t& A,
390  std::size_t Arlo, std::size_t Aclo,
391  std::vector<std::vector<scalar_t>>& sbuf,
392  int dest=0) override;
393  void redistribute_to_tree_from_buffers(const DistM_t& A,
394  std::size_t rlo, std::size_t clo,
395  std::vector<scalar_t*>& pbuf)
396  override;
397  void delete_redistributed_input() override;
398 
399  void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
400  DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
401  long long int& flops) const override;
402 
/**
 * Number of processes given to child 0 when a node of n rows, whose
 * children have nl and nr rows, is spread over P processes.
 *
 * The count is round(P * nl / n), clamped to [1, P-1]. When P <= 1 the
 * result is 1, and Pr() is then also 1. If n == 0, the proportional
 * share is taken as 0, so the result is 1. Without that guard the
 * expression 0/0 yields NaN, and converting NaN to int is undefined
 * behavior.
 *
 * The arithmetic is intentionally kept in float so the split is
 * identical to the previous formula for all n > 0.
 */
static int Pl(std::size_t n, std::size_t nl, std::size_t /*nr*/, int P) {
  const int share = (n == 0) ? 0 : int(std::round(float(P) * nl / n));
  return std::max(1, std::min(share, P-1));
}
/**
 * Number of processes given to child 1: the remainder P - Pl(...),
 * but never less than 1.
 */
static int Pr(std::size_t n, std::size_t nl, std::size_t nr, int P) {
  return std::max(1, P - Pl(n, nl, nr, P));
}
410  int Pl() const {
411  return Pl(this->rows(), child(0)->rows(),
412  child(1)->rows(), Ptotal());
413  }
414  int Pr() const {
415  return Pr(this->rows(), child(0)->rows(),
416  child(1)->rows(), Ptotal());
417  }
418 
419  template<typename T> friend
420  void apply_HSS(Trans ta, const HSSMatrixMPI<T>& a,
421  const DistributedMatrix<T>& b, T beta,
422  DistributedMatrix<T>& c);
423  friend class DistSamples<scalar_t>;
424 
426  };
427 
428  } // end namespace HSS
429 } // end namespace strumpack
430 
431 #endif // HSS_MATRIX_MPI_HPP
This file contains the HSSMatrix class definition as well as implementations for a number of its mem...
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
Contains some simple C++ MPI wrapper utilities.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition: BLACSGrid.hpp:66
int npactives() const
Definition: BLACSGrid.hpp:257
int P() const
Definition: BLACSGrid.hpp:251
const MPIComm & Comm() const
Definition: BLACSGrid.hpp:196
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition: DenseMatrix.hpp:1015
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition: DenseMatrix.hpp:138
Definition: DistributedMatrix.hpp:733
2D block-cyclically distributed matrix, as used by ScaLAPACK.
Definition: DistributedMatrix.hpp:84
Abstract base class for Hierarchically Semi-Separable (HSS) matrices.
Definition: HSSMatrixBase.hpp:83
const HSSMatrixBase< scalar_t > & child(int c) const
Definition: HSSMatrixBase.hpp:188
std::size_t rows() const override
Definition: HSSMatrixBase.hpp:163
Distributed memory implementation of the HSS (Hierarchically Semi-Separable) matrix format.
Definition: HSSMatrixMPI.hpp:68
void mult(Trans op, const DistM_t &x, DistM_t &y) const override
void shift(scalar_t sigma) override
std::size_t nonzeros() const override
std::size_t levels() const override
void solve(DistM_t &b) const override
std::unique_ptr< HSSMatrixBase< scalar_t > > clone() const override
void print_info(std::ostream &out=std::cout, std::size_t roff=0, std::size_t coff=0) const override
std::size_t rank() const override
std::size_t memory() const override
Class containing several options for the HSS code and data-structures.
Definition: HSSOptions.hpp:152
Definition: HSSExtraMPI.hpp:134
Wrapper class around an MPI_Comm object.
Definition: MPIWrapper.hpp:194
MPI_Comm comm() const
Definition: MPIWrapper.hpp:261
Representation of a kernel matrix.
Definition: Kernel.hpp:73
The cluster tree, or partition tree that represents the partitioning of the rows or columns of a hier...
Definition: ClusterTree.hpp:67
Definition: StrumpackOptions.hpp:43
Trans
Definition: DenseMatrix.hpp:51