HSSMatrixMPI.hpp
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  */
28 #ifndef HSS_MATRIX_MPI_HPP
29 #define HSS_MATRIX_MPI_HPP
30 
31 #include <cassert>
32 
33 #include "HSSMatrix.hpp"
34 #include "misc/MPIWrapper.hpp"
35 #include "HSSExtraMPI.hpp"
36 #include "DistSamples.hpp"
37 #include "DistElemMult.hpp"
38 #include "HSSBasisIDMPI.hpp"
39 #include "kernel/Kernel.hpp"
40 
41 namespace strumpack {
42  namespace HSS {
43 
44  template<typename scalar_t> class HSSMatrixBase;
45  template<typename scalar_t> class HSSMatrix;
46 
47  template<typename scalar_t>
48  class HSSMatrixMPI : public HSSMatrixBase<scalar_t> {
49  using real_t = typename RealType<scalar_t>::value_type;
50  using DistM_t = DistributedMatrix<scalar_t>;
51  using DistMW_t = DistributedMatrixWrapper<scalar_t>;
52  using DenseM_t = DenseMatrix<scalar_t>;
53  using DenseMW_t = DenseMatrixWrapper<scalar_t>;
54  using delem_t = typename std::function
55  <void(const std::vector<std::size_t>& I,
56  const std::vector<std::size_t>& J, DistM_t& B)>;
57  using delem_blocks_t = typename std::function
58  <void(const std::vector<std::vector<std::size_t>>& I,
59  const std::vector<std::vector<std::size_t>>& J,
60  std::vector<DistMW_t>& B)>;
61  using elem_t = typename std::function
62  <void(const std::vector<std::size_t>& I,
63  const std::vector<std::size_t>& J, DenseM_t& B)>;
64  using dmult_t = typename std::function
65  <void(DistM_t& R, DistM_t& Sr, DistM_t& Sc)>;
66  using opts_t = HSSOptions<scalar_t>;
67 
68  public:
69  HSSMatrixMPI() : HSSMatrixBase<scalar_t>(0, 0, true) {}
70  HSSMatrixMPI(const DistM_t& A, const opts_t& opts);
71  HSSMatrixMPI
72  (const HSSPartitionTree& t, const DistM_t& A, const opts_t& opts);
73  HSSMatrixMPI
74  (const HSSPartitionTree& t, const BLACSGrid* g, const opts_t& opts);
75  HSSMatrixMPI
76  (std::size_t m, std::size_t n, const BLACSGrid* Agrid,
77  const dmult_t& Amult, const delem_t& Aelem, const opts_t& opts);
78  HSSMatrixMPI
79  (std::size_t m, std::size_t n, const BLACSGrid* Agrid,
80  const dmult_t& Amult, const delem_blocks_t& Aelem, const opts_t& opts);
81  HSSMatrixMPI
82  (const HSSPartitionTree& t, const BLACSGrid* Agrid,
83  const dmult_t& Amult, const delem_t& Aelem, const opts_t& opts);
84  HSSMatrixMPI
85  (kernel::Kernel<real_t>& K, const BLACSGrid* Agrid, const opts_t& opts);
86  HSSMatrixMPI(const HSSMatrixMPI<scalar_t>& other);
87  HSSMatrixMPI(HSSMatrixMPI<scalar_t>&& other) = default;
88  virtual ~HSSMatrixMPI() {}
89 
90  HSSMatrixMPI<scalar_t>& operator=(const HSSMatrixMPI<scalar_t>& other);
91  HSSMatrixMPI<scalar_t>& operator=(HSSMatrixMPI<scalar_t>&& other) = default;
92  std::unique_ptr<HSSMatrixBase<scalar_t>> clone() const override;
93 
94  const HSSMatrixBase<scalar_t>* child(int c) const
95  { return this->_ch[c].get(); }
96  HSSMatrixBase<scalar_t>* child(int c) { return this->_ch[c].get(); }
97 
98  inline const BLACSGrid* grid() const override { return blacs_grid_; }
99  inline const BLACSGrid* grid(const BLACSGrid* grid) const override { return blacs_grid_; }
100  inline const BLACSGrid* grid_local() const override { return blacs_grid_local_; }
101  inline const MPIComm& Comm() const { return grid()->Comm(); }
102  inline MPI_Comm comm() const { return Comm().comm(); }
103  int Ptotal() const override { return grid()->P(); }
104  int Pactive() const override { return grid()->npactives(); }
105 
106 
107  void compress(const DistM_t& A, const opts_t& opts);
108  void compress
109  (const dmult_t& Amult, const delem_t& Aelem, const opts_t& opts);
110  void compress
111  (const dmult_t& Amult, const delem_blocks_t& Aelem, const opts_t& opts);
112  void compress
113  (const kernel::Kernel<real_t>& K, const opts_t& opts);
114 
115  HSSFactorsMPI<scalar_t> factor() const;
116  HSSFactorsMPI<scalar_t> partial_factor() const;
117  void solve(const HSSFactorsMPI<scalar_t>& ULV, DistM_t& b) const;
118  void forward_solve
119  (const HSSFactorsMPI<scalar_t>& ULV, WorkSolveMPI<scalar_t>& w,
120  const DistM_t& b, bool partial) const override;
121  void backward_solve
122  (const HSSFactorsMPI<scalar_t>& ULV,
123  WorkSolveMPI<scalar_t>& w, DistM_t& x) const override;
124 
125  DistM_t apply(const DistM_t& b) const;
126  DistM_t applyC(const DistM_t& b) const;
127 
128  scalar_t get(std::size_t i, std::size_t j) const;
129  DistM_t extract
130  (const std::vector<std::size_t>& I,
131  const std::vector<std::size_t>& J, const BLACSGrid* Bgrid) const;
132  std::vector<DistM_t> extract
133  (const std::vector<std::vector<std::size_t>>& I,
134  const std::vector<std::vector<std::size_t>>& J,
135  const BLACSGrid* Bgrid) const;
136  void extract_add
137  (const std::vector<std::size_t>& I,
138  const std::vector<std::size_t>& J, DistM_t& B) const;
139  void extract_add
140  (const std::vector<std::vector<std::size_t>>& I,
141  const std::vector<std::vector<std::size_t>>& J,
142  std::vector<DistM_t>& B) const;
143 
144  void Schur_update
145  (const HSSFactorsMPI<scalar_t>& f, DistM_t& Theta,
146  DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi) const;
147  void Schur_product_direct
148  (const DistM_t& Theta, const DistM_t& Vhat, const DistM_t& DUB01,
149  const DistM_t& Phi, const DistM_t&_ThetaVhatC,
150  const DistM_t& VhatCPhiC, const DistM_t& R,
151  DistM_t& Sr, DistM_t& Sc) const;
152 
153  std::size_t max_rank() const; // collective on comm()
154  std::size_t total_memory() const; // collective on comm()
155  std::size_t total_nonzeros() const; // collective on comm()
156  std::size_t max_levels() const; // collective on comm()
157  std::size_t rank() const override;
158  std::size_t memory() const override;
159  std::size_t nonzeros() const override;
160  std::size_t levels() const override;
161 
162  void print_info
163  (std::ostream &out=std::cout,
164  std::size_t roff=0, std::size_t coff=0) const override;
165 
166  DistM_t dense() const;
167 
168  void shift(scalar_t sigma) override;
169 
170  const TreeLocalRanges& tree_ranges() const { return _ranges; }
171  void to_block_row
172  (const DistM_t& A, DenseM_t& sub_A, DistM_t& leaf_A) const override;
173  void allocate_block_row
174  (int d, DenseM_t& sub_A, DistM_t& leaf_A) const override;
175  void from_block_row
176  (DistM_t& A, const DenseM_t& sub_A, const DistM_t& leaf_A,
177  const BLACSGrid* lgrid) const override;
178 
179  void delete_trailing_block() override;
180  void reset() override;
181 
182  private:
183  using delemw_t = typename std::function
184  <void(const std::vector<std::size_t>& I,
185  const std::vector<std::size_t>& J,
186  DistM_t& B, DistM_t& A,
187  std::size_t rlo, std::size_t clo,
188  MPI_Comm comm)>;
189 
190  const BLACSGrid* blacs_grid_;
191  const BLACSGrid* blacs_grid_local_;
192  std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
193  std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
194 
195  TreeLocalRanges _ranges;
196 
197  HSSBasisIDMPI<scalar_t> _U;
198  HSSBasisIDMPI<scalar_t> _V;
199  DistM_t _D;
200  DistM_t _B01;
201  DistM_t _B10;
202 
203  // Used to redistribute the original 2D block cyclic matrix
204  // according to the HSS tree
205  DistM_t _A;
206  DistM_t _A01;
207  DistM_t _A10;
208 
209  HSSMatrixMPI
210  (std::size_t m, std::size_t n, const opts_t& opts,
211  const MPIComm& c, int P, std::size_t roff, std::size_t coff);
212  HSSMatrixMPI
213  (const HSSPartitionTree& t, const opts_t& opts,
214  const MPIComm& c, int P, std::size_t roff, std::size_t coff);
215  void setup_hierarchy
216  (const opts_t& opts, std::size_t roff, std::size_t coff);
217  void setup_hierarchy
218  (const HSSPartitionTree& t, const opts_t& opts,
219  std::size_t roff, std::size_t coff);
220  void setup_local_context();
221  void setup_ranges(std::size_t roff, std::size_t coff);
222 
223  void compress_original_nosync
224  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
225  void compress_original_sync
226  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
227  void compress_original_sync
228  (const dmult_t& Amult, const delem_blocks_t& Aelem, const opts_t& opts);
229  void compress_stable_nosync
230  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
231  void compress_stable_sync
232  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
233  void compress_stable_sync
234  (const dmult_t& Amult, const delem_blocks_t& Aelem, const opts_t& opts);
235  void compress_hard_restart_nosync
236  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
237  void compress_hard_restart_sync
238  (const dmult_t& Amult, const delemw_t& Aelem, const opts_t& opts);
239  void compress_hard_restart_sync
240  (const dmult_t& Amult, const delem_blocks_t& Aelem, const opts_t& opts);
241 
242  void compress_recursive_ann
243  (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
244  const delemw_t& Aelem, WorkCompressMPIANN<scalar_t>& w,
245  const opts_t& opts, const BLACSGrid* lg) override;
246  void compute_local_samples_ann
247  (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
248  WorkCompressMPIANN<scalar_t>& w, const delemw_t& Aelem,
249  const opts_t& opts);
250  bool compute_U_V_bases_ann
251  (DistM_t& S, const opts_t& opts, WorkCompressMPIANN<scalar_t>& w);
252  void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
253 
254  void compress_recursive_original
255  (DistSamples<scalar_t>& RS, const delemw_t& Aelem,
256  const opts_t& opts, WorkCompressMPI<scalar_t>& w, int dd) override;
257  void compress_recursive_stable
258  (DistSamples<scalar_t>& RS, const delemw_t& Aelem, const opts_t& opts,
259  WorkCompressMPI<scalar_t>& w, int d, int dd) override;
260  void compute_local_samples
261  (const DistSamples<scalar_t>& RS, WorkCompressMPI<scalar_t>& w, int dd);
262  bool compute_U_V_bases
263  (int d, const opts_t& opts, WorkCompressMPI<scalar_t>& w);
264  void compute_U_basis_stable
265  (const opts_t& opts, WorkCompressMPI<scalar_t>& w, int d, int dd);
266  void compute_V_basis_stable
267  (const opts_t& opts, WorkCompressMPI<scalar_t>& w, int d, int dd);
268  bool update_orthogonal_basis
269  (const opts_t& opts, scalar_t& r_max_0, const DistM_t& S,
270  DistM_t& Q, int d, int dd, bool untouched, int L);
271  void reduce_local_samples
272  (const DistSamples<scalar_t>& RS, WorkCompressMPI<scalar_t>& w,
273  int dd, bool was_compressed);
274  void communicate_child_data(WorkCompressMPI<scalar_t>& w);
275  void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
276  void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
277  void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
278 
279  void compress_level_original
280  (DistSamples<scalar_t>& RS, const opts_t& opts,
281  WorkCompressMPI<scalar_t>& w, int dd, int lvl) override;
282  void compress_level_stable
283  (DistSamples<scalar_t>& RS, const opts_t& opts,
284  WorkCompressMPI<scalar_t>& w, int d, int dd, int lvl) override;
285  void extract_level
286  (const delemw_t& Aelem, const opts_t& opts,
287  WorkCompressMPI<scalar_t>& w, int lvl);
288  void extract_level
289  (const delem_blocks_t& Aelem, const opts_t& opts,
290  WorkCompressMPI<scalar_t>& w, int lvl);
291  void get_extraction_indices
292  (std::vector<std::vector<std::size_t>>& I,
293  std::vector<std::vector<std::size_t>>& J,
294  WorkCompressMPI<scalar_t>& w, int& self, int lvl) override;
295  void get_extraction_indices
296  (std::vector<std::vector<std::size_t>>& I,
297  std::vector<std::vector<std::size_t>>& J, std::vector<DistMW_t>& B,
298  const BLACSGrid* lg, WorkCompressMPI<scalar_t>& w,
299  int& self, int lvl) override;
300  void allgather_extraction_indices
301  (std::vector<std::vector<std::size_t>>& lI,
302  std::vector<std::vector<std::size_t>>& lJ,
303  std::vector<std::vector<std::size_t>>& I,
304  std::vector<std::vector<std::size_t>>& J,
305  int& before, int self, int& after);
306  void extract_D_B
307  (const delemw_t& Aelem, const BLACSGrid* lg, const opts_t& opts,
308  WorkCompressMPI<scalar_t>& w, int lvl) override;
309 
310  void factor_recursive
311  (HSSFactorsMPI<scalar_t>& f, WorkFactorMPI<scalar_t>& w,
312  const BLACSGrid* lg, bool isroot, bool partial) const override;
313 
314  void solve_fwd
315  (const HSSFactorsMPI<scalar_t>& ULV, const DistSubLeaf<scalar_t>& b,
316  WorkSolveMPI<scalar_t>& w, bool partial, bool isroot) const override;
317  void solve_bwd
318  (const HSSFactorsMPI<scalar_t>& ULV, DistSubLeaf<scalar_t>& x,
319  WorkSolveMPI<scalar_t>& w, bool isroot) const override;
320 
321  void apply_fwd
322  (const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
323  bool isroot, long long int flops) const override;
324  void apply_bwd
325  (const DistSubLeaf<scalar_t>& B, scalar_t beta,
326  DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
327  bool isroot, long long int flops) const override;
328  void applyT_fwd
329  (const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
330  bool isroot, long long int flops) const override;
331  void applyT_bwd
332  (const DistSubLeaf<scalar_t>& B, scalar_t beta,
333  DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
334  bool isroot, long long int flops) const override;
335 
336  void extract_fwd
337  (WorkExtractMPI<scalar_t>& w, const BLACSGrid* lg,
338  bool odiag) const override;
339  void extract_bwd
340  (std::vector<Triplet<scalar_t>>& triplets,
341  const BLACSGrid* lg, WorkExtractMPI<scalar_t>& w) const override;
342  void triplets_to_DistM
343  (std::vector<Triplet<scalar_t>>& triplets, DistM_t& B) const;
344  void extract_fwd
345  (WorkExtractBlocksMPI<scalar_t>& w, const BLACSGrid* lg,
346  std::vector<bool>& odiag) const override;
347  void extract_bwd
348  (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
349  const BLACSGrid* lg, WorkExtractBlocksMPI<scalar_t>& w) const override;
350  void triplets_to_DistM
351  (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
352  std::vector<DistM_t>& B) const;
353 
354  void redistribute_to_tree_to_buffers
355  (const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
356  std::vector<std::vector<scalar_t>>& sbuf, int dest=0) override;
357  void redistribute_to_tree_from_buffers
358  (const DistM_t& A, std::size_t rlo, std::size_t clo,
359  std::vector<scalar_t*>& pbuf) override;
360  void delete_redistributed_input() override;
361 
362  void apply_UV_big
363  (DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
364  DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
365  long long int& flops) const override;
366 
367  static int Pl(std::size_t n, std::size_t nl, std::size_t nr, int P) {
368  return std::max
369  (1, std::min(int(std::round(float(P) * nl / n)), P-1));
370  }
371  static int Pr(std::size_t n, std::size_t nl, std::size_t nr, int P)
372  { return std::max(1, P - Pl(n, nl, nr, P)); }
373  int Pl() const {
374  return Pl(this->rows(), this->_ch[0]->rows(),
375  this->_ch[1]->rows(), Ptotal());
376  }
377  int Pr() const {
378  return Pr(this->rows(), this->_ch[0]->rows(),
379  this->_ch[1]->rows(), Ptotal());
380  }
381 
382  template<typename T> friend
383  void apply_HSS(Trans ta, const HSSMatrixMPI<T>& a,
384  const DistributedMatrix<T>& b, T beta,
385  DistributedMatrix<T>& c);
386  friend class DistSamples<scalar_t>;
387  };
388 
389  } // end namespace HSS
390 } // end namespace strumpack
391 
392 #endif // HSS_MATRIX_MPI_HPP
strumpack::BLACSGrid::npactives
int npactives() const
Definition: BLACSGrid.hpp:257
HSSMatrix.hpp
This file contains the HSSMatrix class definition as well as implementations for a number of its mem...
strumpack
Definition: StrumpackOptions.hpp:42
Kernel.hpp
Definitions of several kernel functions, and helper routines. Also provides driver routines for kerne...
strumpack::MPIComm::comm
MPI_Comm comm() const
Definition: MPIWrapper.hpp:257
strumpack::BLACSGrid::P
int P() const
Definition: BLACSGrid.hpp:251
strumpack::CompressionType::HSS
@ HSS
MPIWrapper.hpp
Contains some simple C++ MPI wrapper utilities.
strumpack::Trans
Trans
Definition: DenseMatrix.hpp:50
strumpack::BLACSGrid::Comm
const MPIComm & Comm() const
Definition: BLACSGrid.hpp:196