69 using real_t =
typename RealType<scalar_t>::value_type;
74 using delem_t =
typename std::function
75 <void(
const std::vector<std::size_t>& I,
76 const std::vector<std::size_t>& J,
DistM_t& B)>;
77 using delem_blocks_t =
typename std::function
78 <void(
const std::vector<std::vector<std::size_t>>& I,
79 const std::vector<std::vector<std::size_t>>& J,
80 std::vector<DistMW_t>& B)>;
81 using elem_t =
typename std::function
82 <void(
const std::vector<std::size_t>& I,
83 const std::vector<std::size_t>& J,
DenseM_t& B)>;
84 using dmult_t =
typename std::function
96 const dmult_t& Amult,
const delem_t& Aelem,
99 const dmult_t& Amult,
const delem_blocks_t& Aelem,
102 const dmult_t& Amult,
const delem_t& Aelem,
112 std::unique_ptr<HSSMatrixBase<scalar_t>>
clone()
const override;
115 return this->ch_[c].get();
119 const BLACSGrid* grid()
const override {
return blacs_grid_; }
120 const BLACSGrid* grid(
const BLACSGrid* grid)
const override {
return blacs_grid_; }
121 const BLACSGrid* grid_local()
const override {
return blacs_grid_local_; }
122 const MPIComm& Comm()
const {
return grid()->
Comm(); }
123 MPI_Comm comm()
const {
return Comm().
comm(); }
124 int Ptotal()
const override {
return grid()->
P(); }
125 int Pactive()
const override {
return grid()->
npactives(); }
128 void compress(
const DistM_t& A,
130 void compress(
const dmult_t& Amult,
131 const delem_t& Aelem,
133 void compress(
const dmult_t& Amult,
134 const delem_blocks_t& Aelem,
136 void compress(
const kernel::Kernel<real_t>& K,
const opts_t& opts);
139 void partial_factor();
142 void forward_solve(WorkSolveMPI<scalar_t>& w,
const DistM_t& b,
143 bool partial)
const override;
144 void backward_solve(WorkSolveMPI<scalar_t>& w,
152 scalar_t get(std::size_t i, std::size_t j)
const;
153 DistM_t extract(
const std::vector<std::size_t>& I,
154 const std::vector<std::size_t>& J,
157 extract(
const std::vector<std::vector<std::size_t>>& I,
158 const std::vector<std::vector<std::size_t>>& J,
160 void extract_add(
const std::vector<std::size_t>& I,
161 const std::vector<std::size_t>& J,
DistM_t& B)
const;
162 void extract_add(
const std::vector<std::vector<std::size_t>>& I,
163 const std::vector<std::vector<std::size_t>>& J,
164 std::vector<DistM_t>& B)
const;
168 void Schur_product_direct(
const DistM_t& Theta,
177 std::size_t max_rank()
const;
178 std::size_t total_memory()
const;
179 std::size_t total_nonzeros()
const;
180 std::size_t total_factor_nonzeros()
const;
181 std::size_t max_levels()
const;
182 std::size_t
rank()
const override;
185 std::size_t factor_nonzeros()
const override;
190 std::size_t coff=0)
const override;
194 void shift(scalar_t sigma)
override;
197 void to_block_row(
const DistM_t& A,
199 DistM_t& leaf_A)
const override;
200 void allocate_block_row(
int d, DenseM_t& sub_A,
201 DistM_t& leaf_A)
const override;
202 void from_block_row(DistM_t& A,
203 const DenseM_t& sub_A,
204 const DistM_t& leaf_A,
207 void delete_trailing_block()
override;
208 void reset()
override;
213 using delemw_t =
typename std::function
214 <void(
const std::vector<std::size_t>& I,
215 const std::vector<std::size_t>& J,
216 DistM_t& B, DistM_t& A,
217 std::size_t rlo, std::size_t clo,
222 std::unique_ptr<const BLACSGrid> owned_blacs_grid_;
223 std::unique_ptr<const BLACSGrid> owned_blacs_grid_local_;
225 TreeLocalRanges ranges_;
227 HSSBasisIDMPI<scalar_t> U_, V_;
228 DistM_t D_, B01_, B10_;
232 DistM_t A_, A01_, A10_;
234 HSSMatrixMPI(std::size_t m, std::size_t n,
const opts_t& opts,
236 std::size_t roff, std::size_t coff);
239 std::size_t roff, std::size_t coff);
240 void setup_hierarchy(
const opts_t& opts,
241 std::size_t roff, std::size_t coff);
243 std::size_t roff, std::size_t coff);
/** Set-up helpers for the local grid context and for ranges_ given the
 *  row/column offsets (definitions not in view). */
void setup_local_context();
void setup_ranges(std::size_t roff, std::size_t coff);
247 void compress_original_nosync(
const dmult_t& Amult,
248 const delemw_t& Aelem,
250 void compress_original_sync(
const dmult_t& Amult,
251 const delemw_t& Aelem,
253 void compress_original_sync(
const dmult_t& Amult,
254 const delem_blocks_t& Aelem,
256 void compress_stable_nosync(
const dmult_t& Amult,
257 const delemw_t& Aelem,
259 void compress_stable_sync(
const dmult_t& Amult,
260 const delemw_t& Aelem,
262 void compress_stable_sync(
const dmult_t& Amult,
263 const delem_blocks_t& Aelem,
265 void compress_hard_restart_nosync(
const dmult_t& Amult,
266 const delemw_t& Aelem,
268 void compress_hard_restart_sync(
const dmult_t& Amult,
269 const delemw_t& Aelem,
271 void compress_hard_restart_sync(
const dmult_t& Amult,
272 const delem_blocks_t& Aelem,
277 const delemw_t& Aelem,
278 WorkCompressMPIANN<scalar_t>& w,
283 WorkCompressMPIANN<scalar_t>& w,
284 const delemw_t& Aelem,
286 bool compute_U_V_bases_ann(DistM_t& S,
const opts_t& opts,
287 WorkCompressMPIANN<scalar_t>& w);
288 void communicate_child_data_ann(WorkCompressMPIANN<scalar_t>& w);
290 void compress_recursive_original(DistSamples<scalar_t>& RS,
291 const delemw_t& Aelem,
293 WorkCompressMPI<scalar_t>& w,
295 void compress_recursive_stable(DistSamples<scalar_t>& RS,
296 const delemw_t& Aelem,
298 WorkCompressMPI<scalar_t>& w,
299 int d,
int dd)
override;
300 void compute_local_samples(
const DistSamples<scalar_t>& RS,
301 WorkCompressMPI<scalar_t>& w,
int dd);
302 bool compute_U_V_bases(
int d,
const opts_t& opts,
303 WorkCompressMPI<scalar_t>& w);
304 void compute_U_basis_stable(
const opts_t& opts,
305 WorkCompressMPI<scalar_t>& w,
307 void compute_V_basis_stable(
const opts_t& opts,
308 WorkCompressMPI<scalar_t>& w,
310 bool update_orthogonal_basis(
const opts_t& opts,
311 scalar_t& r_max_0,
const DistM_t& S,
312 DistM_t& Q,
int d,
int dd,
313 bool untouched,
int L);
314 void reduce_local_samples(
const DistSamples<scalar_t>& RS,
315 WorkCompressMPI<scalar_t>& w,
316 int dd,
bool was_compressed);
317 void communicate_child_data(WorkCompressMPI<scalar_t>& w);
318 void notify_inactives_J(WorkCompressMPI<scalar_t>& w);
319 void notify_inactives_J(WorkCompressMPIANN<scalar_t>& w);
320 void notify_inactives_states(WorkCompressMPI<scalar_t>& w);
322 void compress_level_original(DistSamples<scalar_t>& RS,
324 WorkCompressMPI<scalar_t>& w,
325 int dd,
int lvl)
override;
326 void compress_level_stable(DistSamples<scalar_t>& RS,
328 WorkCompressMPI<scalar_t>& w,
329 int d,
int dd,
int lvl)
override;
330 void extract_level(
const delemw_t& Aelem,
const opts_t& opts,
331 WorkCompressMPI<scalar_t>& w,
int lvl);
332 void extract_level(
const delem_blocks_t& Aelem,
const opts_t& opts,
333 WorkCompressMPI<scalar_t>& w,
int lvl);
335 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
336 std::vector<std::vector<std::size_t>>& J,
337 WorkCompressMPI<scalar_t>& w,
338 int& self,
int lvl)
override;
339 void get_extraction_indices(std::vector<std::vector<std::size_t>>& I,
340 std::vector<std::vector<std::size_t>>& J,
341 std::vector<DistMW_t>& B,
343 WorkCompressMPI<scalar_t>& w,
344 int& self,
int lvl)
override;
/** Gathers local index sets lI/lJ into global I/J. `before` and `after`
 *  are outputs and `self` is an input (definition not in view). */
void allgather_extraction_indices(
    std::vector<std::vector<std::size_t>>& lI,
    std::vector<std::vector<std::size_t>>& lJ,
    std::vector<std::vector<std::size_t>>& I,
    std::vector<std::vector<std::size_t>>& J,
    int& before, int self, int& after);
351 void extract_D_B(
const delemw_t& Aelem,
353 WorkCompressMPI<scalar_t>& w,
int lvl)
override;
355 void factor_recursive(WorkFactorMPI<scalar_t>& w,
357 bool isroot,
bool partial)
override;
359 void solve_fwd(
const DistSubLeaf<scalar_t>& b,
360 WorkSolveMPI<scalar_t>& w,
361 bool partial,
bool isroot)
const override;
362 void solve_bwd(DistSubLeaf<scalar_t>& x,
363 WorkSolveMPI<scalar_t>& w,
bool isroot)
const override;
365 void apply_fwd(
const DistSubLeaf<scalar_t>& B,
366 WorkApplyMPI<scalar_t>& w,
367 bool isroot,
long long int flops)
const override;
368 void apply_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
369 DistSubLeaf<scalar_t>&
C, WorkApplyMPI<scalar_t>& w,
370 bool isroot,
long long int flops)
const override;
371 void applyT_fwd(
const DistSubLeaf<scalar_t>& B,
372 WorkApplyMPI<scalar_t>& w,
373 bool isroot,
long long int flops)
const override;
374 void applyT_bwd(
const DistSubLeaf<scalar_t>& B, scalar_t beta,
375 DistSubLeaf<scalar_t>&
C, WorkApplyMPI<scalar_t>& w,
376 bool isroot,
long long int flops)
const override;
378 void extract_fwd(WorkExtractMPI<scalar_t>& w,
const BLACSGrid* lg,
379 bool odiag)
const override;
380 void extract_bwd(std::vector<Triplet<scalar_t>>& triplets,
382 WorkExtractMPI<scalar_t>& w)
const override;
383 void triplets_to_DistM(std::vector<Triplet<scalar_t>>& triplets,
385 void extract_fwd(WorkExtractBlocksMPI<scalar_t>& w,
387 std::vector<bool>& odiag)
const override;
388 void extract_bwd(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
390 WorkExtractBlocksMPI<scalar_t>& w)
const override;
391 void triplets_to_DistM(std::vector<std::vector<Triplet<scalar_t>>>& triplets,
392 std::vector<DistM_t>& B)
const;
394 void redistribute_to_tree_to_buffers(
const DistM_t& A,
395 std::size_t Arlo, std::size_t Aclo,
396 std::vector<std::vector<scalar_t>>& sbuf,
397 int dest=0)
override;
398 void redistribute_to_tree_from_buffers(
const DistM_t& A,
399 std::size_t rlo, std::size_t clo,
400 std::vector<scalar_t*>& pbuf)
402 void delete_redistributed_input()
override;
404 void apply_UV_big(DistSubLeaf<scalar_t>& Theta, DistM_t& Uop,
405 DistSubLeaf<scalar_t>& Phi, DistM_t& Vop,
406 long long int& flops)
const override;
408 static int Pl(std::size_t n, std::size_t nl, std::size_t nr,
int P) {
410 (1, std::min(
int(std::round(
float(P) * nl / n)), P-1));
412 static int Pr(std::size_t n, std::size_t nl, std::size_t nr,
int P) {
413 return std::max(1, P - Pl(n, nl, nr, P));
416 return Pl(this->
rows(), child(0)->
rows(),
417 child(1)->
rows(), Ptotal());
420 return Pr(this->
rows(), child(0)->
rows(),
421 child(1)->
rows(), Ptotal());
424 template<
typename T>
friend
425 void apply_HSS(
Trans ta,
const HSSMatrixMPI<T>& a,
426 const DistributedMatrix<T>& b,
T beta,
427 DistributedMatrix<T>& c);
428 friend class DistSamples<scalar_t>;
433 using structured::StructuredMatrix<scalar_t>
::mult;
434 using structured::StructuredMatrix<scalar_t>
::solve;