28 #ifndef HSS_MATRIX_MPI_HPP
29 #define HSS_MATRIX_MPI_HPP
35 #include "HSSExtraMPI.hpp"
36 #include "DistSamples.hpp"
37 #include "DistElemMult.hpp"
38 #include "HSSBasisIDMPI.hpp"
// Forward declaration of the class template HSSMatrixMPI derives from.
44 template<
typename scalar_t>
class HSSMatrixBase;
// Forward declaration of the HSSMatrix class template.
45 template<
typename scalar_t>
class HSSMatrix;
// HSSMatrixMPI: class template over scalar_t, publicly derived from
// HSSMatrixBase<scalar_t>. The aliases that follow shorten the matrix and
// callback types used throughout the member declarations.
47 template<
typename scalar_t>
48 class HSSMatrixMPI :
public HSSMatrixBase<scalar_t> {
// Value type that RealType associates with scalar_t.
49 using real_t =
typename RealType<scalar_t>::value_type;
// Distributed and dense matrix types (and their wrapper variants) over
// scalar_t.
50 using DistM_t = DistributedMatrix<scalar_t>;
51 using DistMW_t = DistributedMatrixWrapper<scalar_t>;
52 using DenseM_t = DenseMatrix<scalar_t>;
53 using DenseMW_t = DenseMatrixWrapper<scalar_t>;
// Callback taking index vectors I and J and a distributed matrix B.
// NOTE(review): presumably fills B with the entries at rows I / columns J
// -- confirm against the implementation.
54 using delem_t =
typename std::function
55 <void(
const std::vector<std::size_t>& I,
56 const std::vector<std::size_t>& J, DistM_t& B)>;
// Same shape of callback for several index sets at once: one (I, J) pair
// per entry of B, where B holds distributed matrix wrappers.
57 using delem_blocks_t =
typename std::function
58 <void(
const std::vector<std::vector<std::size_t>>& I,
59 const std::vector<std::vector<std::size_t>>& J,
60 std::vector<DistMW_t>& B)>;
// Variant of delem_t whose output B is a dense (DenseM_t) matrix.
61 using elem_t =
typename std::function
62 <void(
const std::vector<std::size_t>& I,
63 const std::vector<std::size_t>& J, DenseM_t& B)>;
// Callback over three distributed matrices R, Sr and Sc, all passed by
// non-const reference.
// NOTE(review): presumably a matrix-times-random-block product producing
// row/column samples -- confirm.
64 using dmult_t =
typename std::function
65 <void(DistM_t& R, DistM_t& Sr, DistM_t& Sc)>;
// Options type for this scalar type.
66 using opts_t = HSSOptions<scalar_t>;
// Default constructor: forwards (0, 0, true) to the HSSMatrixBase
// constructor and does nothing else.
69 HSSMatrixMPI() : HSSMatrixBase<scalar_t>(0, 0, true) {}
// Constructor from a distributed matrix A and options.
70 HSSMatrixMPI(
const DistM_t& A,
const opts_t& opts);
// NOTE(review): the next six parameter lists have lost the text that
// precedes them in this view; they sit between constructor declarations
// and are presumably further HSSMatrixMPI constructor overloads -- confirm.
// Takes a partition tree, a distributed matrix and options.
72 (
const HSSPartitionTree& t,
const DistM_t& A,
const opts_t& opts);
// Takes a partition tree, a BLACS grid pointer and options.
74 (
const HSSPartitionTree& t,
const BLACSGrid* g,
const opts_t& opts);
// Takes sizes m and n, a grid, a dmult_t callback, a delem_t callback and
// options.
76 (std::size_t m, std::size_t n,
const BLACSGrid* Agrid,
77 const dmult_t& Amult,
const delem_t& Aelem,
const opts_t& opts);
// Same as the previous list, but with a delem_blocks_t element callback.
79 (std::size_t m, std::size_t n,
const BLACSGrid* Agrid,
80 const dmult_t& Amult,
const delem_blocks_t& Aelem,
const opts_t& opts);
// Takes a partition tree instead of explicit sizes, plus grid, callbacks
// and options.
82 (
const HSSPartitionTree& t,
const BLACSGrid* Agrid,
83 const dmult_t& Amult,
const delem_t& Aelem,
const opts_t& opts);
// Takes a kernel over real_t (by non-const reference), a grid and options.
85 (kernel::Kernel<real_t>& K,
const BLACSGrid* Agrid,
const opts_t& opts);
// Copy constructor (declared here, defined elsewhere).
86 HSSMatrixMPI(
const HSSMatrixMPI<scalar_t>& other);
// Move constructor: compiler-generated.
87 HSSMatrixMPI(HSSMatrixMPI<scalar_t>&& other) =
default;
// Virtual destructor with an empty body.
88 virtual ~HSSMatrixMPI() {}
// Copy assignment (declared here, defined elsewhere).
90 HSSMatrixMPI<scalar_t>& operator=(
const HSSMatrixMPI<scalar_t>& other);
// Move assignment: compiler-generated.
91 HSSMatrixMPI<scalar_t>& operator=(HSSMatrixMPI<scalar_t>&& other) =
default;
// Overrides the base-class clone(); returns a unique_ptr to
// HSSMatrixBase<scalar_t>.
92 std::unique_ptr<HSSMatrixBase<scalar_t>> clone()
const override;
94 const HSSMatrixBase<scalar_t>* child(
int c)
const
95 {
return this->_ch[c].get(); }
96 HSSMatrixBase<scalar_t>* child(
int c) {
return this->_ch[c].get(); }
98 inline const BLACSGrid* grid()
const override {
return blacs_grid_; }
99 inline const BLACSGrid* grid(
const BLACSGrid* grid)
const override {
return blacs_grid_; }
100 inline const BLACSGrid* grid_local()
const override {
return blacs_grid_local_; }
101 inline const MPIComm& Comm()
const {
return grid()->
Comm(); }
102 inline MPI_Comm comm()
const {
return Comm().
comm(); }
103 int Ptotal()
const override {
return grid()->
P(); }
104 int Pactive()
const override {
return grid()->
npactives(); }
// compress: takes a distributed matrix A and options.
107 void compress(
const DistM_t& A,
const opts_t& opts);
// NOTE(review): the next three parameter lists have lost their return type
// and name in this view; presumably further compress overloads -- confirm.
// Takes a dmult_t callback, a delem_t callback and options.
109 (
const dmult_t& Amult,
const delem_t& Aelem,
const opts_t& opts);
// Takes a dmult_t callback, a delem_blocks_t callback and options.
111 (
const dmult_t& Amult,
const delem_blocks_t& Aelem,
const opts_t& opts);
// Takes a kernel over real_t (by const reference) and options.
113 (
const kernel::Kernel<real_t>& K,
const opts_t& opts);
// factor / partial_factor: const members returning an
// HSSFactorsMPI<scalar_t> by value.
115 HSSFactorsMPI<scalar_t> factor()
const;
116 HSSFactorsMPI<scalar_t> partial_factor()
const;
// solve: takes factors ULV by const reference and b by non-const reference.
// NOTE(review): presumably b is overwritten with the solution -- confirm.
117 void solve(
const HSSFactorsMPI<scalar_t>& ULV, DistM_t& b)
const;
// NOTE(review): the two overriding declarations below have lost their
// return type and name in this view. Both take the factors and a
// WorkSolveMPI workspace; the first reads b and a `partial` flag, the second
// receives a mutable x.
119 (
const HSSFactorsMPI<scalar_t>& ULV, WorkSolveMPI<scalar_t>& w,
120 const DistM_t& b,
bool partial)
const override;
122 (
const HSSFactorsMPI<scalar_t>& ULV,
123 WorkSolveMPI<scalar_t>& w, DistM_t& x)
const override;
// apply / applyC: const members taking a distributed matrix b and returning
// a new distributed matrix by value.
// NOTE(review): presumably the product with b, and with the conjugate
// transpose for applyC -- confirm.
125 DistM_t apply(
const DistM_t& b)
const;
126 DistM_t applyC(
const DistM_t& b)
const;
// get: const member returning a scalar_t for the index pair (i, j).
128 scalar_t get(std::size_t i, std::size_t j)
const;
// NOTE(review): this parameter list has lost its return type and name in
// this view; it takes single index vectors I, J and a target grid, so it is
// presumably the single-block form of extract below -- confirm.
130 (
const std::vector<std::size_t>& I,
131 const std::vector<std::size_t>& J,
const BLACSGrid* Bgrid)
const;
// extract (multi-block form): takes vectors of index vectors I and J and a
// grid pointer; returns a vector of distributed matrices by value.
132 std::vector<DistM_t> extract
133 (
const std::vector<std::vector<std::size_t>>& I,
134 const std::vector<std::vector<std::size_t>>& J,
135 const BLACSGrid* Bgrid)
const;
// NOTE(review): the next two const declarations have lost their return type
// and name. They take the same index arguments as above but receive the
// output through B (one DistM_t, respectively a vector of DistM_t).
137 (
const std::vector<std::size_t>& I,
138 const std::vector<std::size_t>& J, DistM_t& B)
const;
140 (
const std::vector<std::vector<std::size_t>>& I,
141 const std::vector<std::vector<std::size_t>>& J,
142 std::vector<DistM_t>& B)
const;
// NOTE(review): name missing in this view. Takes factors f by const
// reference and four distributed matrices (Theta, Vhat, DUB01, Phi) by
// non-const reference.
145 (
const HSSFactorsMPI<scalar_t>& f, DistM_t& Theta,
146 DistM_t& Vhat, DistM_t& DUB01, DistM_t& Phi)
const;
147 void Schur_product_direct
148 (
const DistM_t& Theta,
const DistM_t& Vhat,
const DistM_t& DUB01,
149 const DistM_t& Phi,
const DistM_t&_ThetaVhatC,
150 const DistM_t& VhatCPhiC,
const DistM_t& R,
151 DistM_t& Sr, DistM_t& Sc)
const;
// Size statistics, all const and returning std::size_t. These four are not
// marked override.
// NOTE(review): presumably aggregates over all processes of the per-node
// quantities declared further down -- confirm.
153 std::size_t max_rank()
const;
154 std::size_t total_memory()
const;
155 std::size_t total_nonzeros()
const;
156 std::size_t max_levels()
const;
// Overrides of the corresponding base-class queries.
157 std::size_t rank()
const override;
158 std::size_t memory()
const override;
159 std::size_t nonzeros()
const override;
160 std::size_t levels()
const override;
// NOTE(review): return type and name are missing in this view. This
// overriding const member takes an output stream (default std::cout) and
// row/column offsets (default 0).
163 (std::ostream &out=std::cout,
164 std::size_t roff=0, std::size_t coff=0)
const override;
// dense: const member returning a distributed matrix by value.
166 DistM_t dense()
const;
// shift: overriding, non-const member taking a scalar sigma.
// NOTE(review): presumably adds sigma to the diagonal -- confirm.
168 void shift(scalar_t sigma)
override;
170 const TreeLocalRanges& tree_ranges()
const {
return _ranges; }
// NOTE(review): return type and name missing in this view. Overriding const
// member that reads a distributed matrix A and receives sub_A (dense) and
// leaf_A (distributed) by non-const reference.
172 (
const DistM_t& A, DenseM_t& sub_A, DistM_t& leaf_A)
const override;
// allocate_block_row: takes an int d and receives sub_A / leaf_A by
// non-const reference.
173 void allocate_block_row
174 (
int d, DenseM_t& sub_A, DistM_t& leaf_A)
const override;
// NOTE(review): return type and name missing in this view. Argument
// constness is reversed relative to the first declaration of this group: A
// is mutable, sub_A and leaf_A are read-only; also takes a grid pointer.
176 (DistM_t& A,
const DenseM_t& sub_A,
const DistM_t& leaf_A,
177 const BLACSGrid* lgrid)
const override;
// Non-const overrides taking no arguments.
179 void delete_trailing_block()
override;
180 void reset()
override;
// Element callback taking, besides I, J and B as in delem_t, a second
// distributed matrix A and offsets rlo / clo.
// NOTE(review): the rest of this parameter list and the closing `)>;` are
// not visible in this view.
183 using delemw_t =
typename std::function
184 <void(
const std::vector<std::size_t>& I,
185 const std::vector<std::size_t>& J,
186 DistM_t& B, DistM_t& A,
187 std::size_t rlo, std::size_t clo,
// Grid pointers returned by grid() and grid_local() respectively.
190 const BLACSGrid* blacs_grid_;
191 const BLACSGrid* blacs_grid_local_;
// Owning counterparts of the two pointers above.
// NOTE(review): presumably set only when this object created the grid
// itself -- confirm.
192 std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
193 std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
// Value returned by tree_ranges().
195 TreeLocalRanges _ranges;
// Two HSSBasisIDMPI members.
// NOTE(review): presumably the row (U) and column (V) bases -- confirm.
197 HSSBasisIDMPI<scalar_t> _U;
198 HSSBasisIDMPI<scalar_t> _V;
// NOTE(review): the four parameter lists below have lost the text that
// precedes them in this view; presumably non-public constructors or setup
// helpers -- confirm. Each takes options and row/column offsets
// (roff, coff); the first two also take a communicator and an int P.
// Sizes m, n given explicitly.
210 (std::size_t m, std::size_t n,
const opts_t& opts,
211 const MPIComm& c,
int P, std::size_t roff, std::size_t coff);
// Partition tree given instead of sizes.
213 (
const HSSPartitionTree& t,
const opts_t& opts,
214 const MPIComm& c,
int P, std::size_t roff, std::size_t coff);
// Options and offsets only.
216 (
const opts_t& opts, std::size_t roff, std::size_t coff);
// Partition tree, options and offsets.
218 (
const HSSPartitionTree& t,
const opts_t& opts,
219 std::size_t roff, std::size_t coff);
// Setup helpers (declarations only).
220 void setup_local_context();
221 void setup_ranges(std::size_t roff, std::size_t coff);
// Compression drivers (declarations only). Three name families -- original,
// stable, hard_restart -- each with a *_nosync form and two *_sync
// overloads. All take a dmult_t callback and options; they differ in the
// element callback type (delemw_t or delem_blocks_t).
// NOTE(review): how the variants differ in behaviour is not visible here;
// see their definitions.
223 void compress_original_nosync
224 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
225 void compress_original_sync
226 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
// Block-callback overload.
227 void compress_original_sync
228 (
const dmult_t& Amult,
const delem_blocks_t& Aelem,
const opts_t& opts);
229 void compress_stable_nosync
230 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
231 void compress_stable_sync
232 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
// Block-callback overload.
233 void compress_stable_sync
234 (
const dmult_t& Amult,
const delem_blocks_t& Aelem,
const opts_t& opts);
235 void compress_hard_restart_nosync
236 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
237 void compress_hard_restart_sync
238 (
const dmult_t& Amult,
const delemw_t& Aelem,
const opts_t& opts);
// Block-callback overload.
239 void compress_hard_restart_sync
240 (
const dmult_t& Amult,
const delem_blocks_t& Aelem,
const opts_t& opts);
// Helpers whose names end in _ann; they share the `ann` (uint32 dense
// matrix) and `scores` (real_t dense matrix) arguments and the
// WorkCompressMPIANN workspace.
// compress_recursive_ann: override; additionally takes the element
// callback, options and a grid pointer.
242 void compress_recursive_ann
243 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
244 const delemw_t& Aelem, WorkCompressMPIANN<scalar_t>& w,
245 const opts_t& opts,
const BLACSGrid* lg)
override;
// NOTE(review): the end of this parameter list (after Aelem) is not visible
// in this view.
246 void compute_local_samples_ann
247 (DenseMatrix<std::uint32_t>& ann, DenseMatrix<real_t>& scores,
248 WorkCompressMPIANN<scalar_t>& w,
const delemw_t& Aelem,
// compute_U_V_bases_ann: returns bool; takes a mutable distributed matrix
// S, options and the workspace.
250 bool compute_U_V_bases_ann
251 (DistM_t& S,
const opts_t& opts, WorkCompressMPIANN<scalar_t>& w);
252 void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
// Recursive compression helpers built around DistSamples and the
// WorkCompressMPI workspace (declarations only).
// NOTE(review): the meaning of the integer arguments d and dd is not
// visible here; see the definitions.
// Overrides of the two recursive entry points.
254 void compress_recursive_original
255 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
256 const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int dd)
override;
257 void compress_recursive_stable
258 (DistSamples<scalar_t>& RS,
const delemw_t& Aelem,
const opts_t& opts,
259 WorkCompressMPI<scalar_t>& w,
int d,
int dd)
override;
// Non-virtual helpers.
260 void compute_local_samples
261 (
const DistSamples<scalar_t>& RS, WorkCompressMPI<scalar_t>& w,
int dd);
// Returns bool.
262 bool compute_U_V_bases
263 (
int d,
const opts_t& opts, WorkCompressMPI<scalar_t>& w);
264 void compute_U_basis_stable
265 (
const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int d,
int dd);
266 void compute_V_basis_stable
267 (
const opts_t& opts, WorkCompressMPI<scalar_t>& w,
int d,
int dd);
// Returns bool; r_max_0 and Q are passed by non-const reference, S is
// read-only.
268 bool update_orthogonal_basis
269 (
const opts_t& opts, scalar_t& r_max_0,
const DistM_t& S,
270 DistM_t& Q,
int d,
int dd,
bool untouched,
int L);
271 void reduce_local_samples
272 (
const DistSamples<scalar_t>& RS, WorkCompressMPI<scalar_t>& w,
273 int dd,
bool was_compressed);
// Communication helpers; notify_inactives_J is overloaded on the workspace
// type.
274 void communicate_child_data(WorkCompressMPI<scalar_t>& w);
275 void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
276 void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
277 void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
// Level-wise compression helpers (declarations only); each takes an int lvl
// in addition to the samples / workspace arguments.
279 void compress_level_original
280 (DistSamples<scalar_t>& RS,
const opts_t& opts,
281 WorkCompressMPI<scalar_t>& w,
int dd,
int lvl)
override;
282 void compress_level_stable
283 (DistSamples<scalar_t>& RS,
const opts_t& opts,
284 WorkCompressMPI<scalar_t>& w,
int d,
int dd,
int lvl)
override;
// NOTE(review): the next two declarations have lost their return type and
// name in this view. They differ only in the element callback type
// (delemw_t or delem_blocks_t).
286 (
const delemw_t& Aelem,
const opts_t& opts,
287 WorkCompressMPI<scalar_t>& w,
int lvl);
289 (
const delem_blocks_t& Aelem,
const opts_t& opts,
290 WorkCompressMPI<scalar_t>& w,
int lvl);
// get_extraction_indices: override; I, J and `self` are passed by non-const
// reference.
291 void get_extraction_indices
292 (std::vector<std::vector<std::size_t>>& I,
293 std::vector<std::vector<std::size_t>>& J,
294 WorkCompressMPI<scalar_t>& w,
int&
self,
int lvl)
override;
// Overload that additionally takes a vector of matrix wrappers B and a grid
// pointer.
295 void get_extraction_indices
296 (std::vector<std::vector<std::size_t>>& I,
297 std::vector<std::vector<std::size_t>>& J, std::vector<DistMW_t>& B,
298 const BLACSGrid* lg, WorkCompressMPI<scalar_t>& w,
299 int&
self,
int lvl)
override;
// allgather_extraction_indices: takes two pairs of index-set vectors
// (lI / lJ and I / J); `before` and `after` are passed by reference, `self`
// by value.
// NOTE(review): presumably lI / lJ are the local sets gathered into I / J
// -- confirm.
300 void allgather_extraction_indices
301 (std::vector<std::vector<std::size_t>>& lI,
302 std::vector<std::vector<std::size_t>>& lJ,
303 std::vector<std::vector<std::size_t>>& I,
304 std::vector<std::vector<std::size_t>>& J,
305 int& before,
int self,
int& after);
// NOTE(review): return type and name missing in this view. Overriding member
// taking the element callback, a grid pointer, options, the workspace and
// lvl.
307 (
const delemw_t& Aelem,
const BLACSGrid* lg,
const opts_t& opts,
308 WorkCompressMPI<scalar_t>& w,
int lvl)
override;
// factor_recursive: const override; receives the factors f and a
// WorkFactorMPI workspace by non-const reference, plus a grid pointer and
// the isroot / partial flags.
310 void factor_recursive
311 (HSSFactorsMPI<scalar_t>& f, WorkFactorMPI<scalar_t>& w,
312 const BLACSGrid* lg,
bool isroot,
bool partial)
const override;
// NOTE(review): every declaration from here to the end of this group has
// lost its return type and name in this view. All are const overrides.
// Two members working on factors ULV and a WorkSolveMPI workspace: the first
// reads b, the second receives a mutable x.
315 (
const HSSFactorsMPI<scalar_t>& ULV,
const DistSubLeaf<scalar_t>& b,
316 WorkSolveMPI<scalar_t>& w,
bool partial,
bool isroot)
const override;
318 (
const HSSFactorsMPI<scalar_t>& ULV, DistSubLeaf<scalar_t>& x,
319 WorkSolveMPI<scalar_t>& w,
bool isroot)
const override;
// Four members working on a WorkApplyMPI workspace, in two pairs with
// identical parameter lists: (B, w, isroot, flops) and
// (B, beta, C, w, isroot, flops). flops is passed by value in all four.
322 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
323 bool isroot,
long long int flops)
const override;
325 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
326 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
327 bool isroot,
long long int flops)
const override;
329 (
const DistSubLeaf<scalar_t>& B, WorkApplyMPI<scalar_t>& w,
330 bool isroot,
long long int flops)
const override;
332 (
const DistSubLeaf<scalar_t>& B, scalar_t beta,
333 DistSubLeaf<scalar_t>& C, WorkApplyMPI<scalar_t>& w,
334 bool isroot,
long long int flops)
const override;
// NOTE(review): the four overriding declarations in this group have lost
// their return type and name in this view.
// Single-block forms: WorkExtractMPI workspace, a grid pointer, and either
// an odiag flag or a vector of triplets.
337 (WorkExtractMPI<scalar_t>& w,
const BLACSGrid* lg,
338 bool odiag)
const override;
340 (std::vector<Triplet<scalar_t>>& triplets,
341 const BLACSGrid* lg, WorkExtractMPI<scalar_t>& w)
const override;
// triplets_to_DistM: takes a vector of triplets and a mutable distributed
// matrix B.
// NOTE(review): presumably writes the triplets into B -- confirm.
342 void triplets_to_DistM
343 (std::vector<Triplet<scalar_t>>& triplets, DistM_t& B)
const;
// Multi-block forms: WorkExtractBlocksMPI workspace; odiag becomes a
// vector<bool> and the triplets a vector of vectors.
345 (WorkExtractBlocksMPI<scalar_t>& w,
const BLACSGrid* lg,
346 std::vector<bool>& odiag)
const override;
348 (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
349 const BLACSGrid* lg, WorkExtractBlocksMPI<scalar_t>& w)
const override;
// triplets_to_DistM, multi-block overload: one triplet vector per matrix
// in B.
350 void triplets_to_DistM
351 (std::vector<std::vector<Triplet<scalar_t>>>& triplets,
352 std::vector<DistM_t>& B)
const;
// redistribute_to_tree_to_buffers: override; reads A with offsets
// Arlo / Aclo and receives sbuf (one scalar vector per entry) by non-const
// reference. dest defaults to 0.
354 void redistribute_to_tree_to_buffers
355 (
const DistM_t& A, std::size_t Arlo, std::size_t Aclo,
356 std::vector<std::vector<scalar_t>>& sbuf,
int dest=0)
override;
// redistribute_to_tree_from_buffers: override; takes A, offsets rlo / clo
// and a vector of raw scalar pointers pbuf.
357 void redistribute_to_tree_from_buffers
358 (
const DistM_t& A, std::size_t rlo, std::size_t clo,
359 std::vector<scalar_t*>& pbuf)
override;
360 void delete_redistributed_input()
override;
// NOTE(review): return type and name missing in this view. Const override
// taking Theta / Phi as DistSubLeaf, Uop / Vop as distributed matrices and,
// unlike the by-value flops arguments above, flops by reference.
363 (DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
364 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
365 long long int& flops)
const override;
// Static helper Pl(n, nl, nr, P). The visible expression rounds
// float(P) * nl / n to an int and caps it at P-1 with std::min; nr does not
// appear in the visible part.
// NOTE(review): the statement head before `(1, std::min(...` and the
// closing brace are missing in this view; presumably `return std::max`, as
// in Pr -- confirm.
367 static int Pl(std::size_t n, std::size_t nl, std::size_t nr,
int P) {
369 (1, std::min(
int(std::round(
float(P) * nl / n)), P-1));
371 static int Pr(std::size_t n, std::size_t nl, std::size_t nr,
int P)
372 {
return std::max(1, P - Pl(n, nl, nr, P)); }
// NOTE(review): the signatures and braces enclosing the two return
// statements below are missing in this view; presumably bodies of
// argument-free member wrappers -- confirm.
// Calls the static Pl with this node's row count, the row counts of
// children 0 and 1, and Ptotal().
374 return Pl(this->rows(), this->_ch[0]->rows(),
375 this->_ch[1]->rows(), Ptotal());
// Same arguments, forwarded to the static Pr.
378 return Pr(this->rows(), this->_ch[0]->rows(),
379 this->_ch[1]->rows(), Ptotal());
// Befriends every instantiation of the function template apply_HSS, which
// takes a Trans flag, an HSSMatrixMPI<T>, a read-only distributed matrix b,
// a scalar beta and a mutable distributed matrix c.
382 template<
typename T>
friend
383 void apply_HSS(
Trans ta,
const HSSMatrixMPI<T>& a,
384 const DistributedMatrix<T>& b, T beta,
385 DistributedMatrix<T>& c);
// Befriends DistSamples for the same scalar type.
386 friend class DistSamples<scalar_t>;
392 #endif // HSS_MATRIX_MPI_HPP