Loading...
Searching...
No Matches
CSRMatrixMPI.hpp
Go to the documentation of this file.
1/*
2 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3 * Regents of the University of California, through Lawrence Berkeley
4 * National Laboratory (subject to receipt of any required approvals
5 * from the U.S. Dept. of Energy). All rights reserved.
6 *
7 * If you have questions about your rights to use or distribute this
8 * software, please contact Berkeley Lab's Technology Transfer
9 * Department at TTD@lbl.gov.
10 *
11 * NOTICE. This software is owned by the U.S. Department of Energy. As
12 * such, the U.S. Government has been granted for itself and others
13 * acting on its behalf a paid-up, nonexclusive, irrevocable,
14 * worldwide license in the Software to reproduce, prepare derivative
15 * works, and perform publicly and display publicly. Beginning five
16 * (5) years after the date permission to assert copyright is obtained
17 * from the U.S. Department of Energy, and subject to any subsequent
18 * five (5) year renewals, the U.S. Government is granted for itself
19 * and others acting on its behalf a paid-up, nonexclusive,
20 * irrevocable, worldwide license in the Software to reproduce,
21 * prepare derivative works, distribute copies to the public, perform
22 * publicly and display publicly, and to permit others to do so.
23 *
24 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25 * (Lawrence Berkeley National Lab, Computational Research
26 * Division).
27 */
33#ifndef STRUMPACK_CSRMATRIX_MPI_HPP
34#define STRUMPACK_CSRMATRIX_MPI_HPP
35
36#include <vector>
37#include <tuple>
38#include <memory>
39
40#include "CSRMatrix.hpp"
41#include "misc/MPIWrapper.hpp"
42#include "dense/BLASLAPACKWrapper.hpp"
43#include "CSRGraph.hpp"
44
45
46namespace strumpack {
47
/**
 * Communication buffers for the distributed sparse matrix-vector
 * product of CSRMatrixMPI.
 *
 * Plain public data. NOTE(review): these are presumably filled once by
 * CSRMatrixMPI::setup_spmv_buffers() (which sets `initialized`) and
 * then reused by subsequent spmv calls -- confirm in the implementation.
 */
template<typename scalar_t,typename integer_t> class SPMVBuffers {
public:
  /** Whether the buffers below have already been set up. */
  bool initialized = false;
  // NOTE(review): judging by their names, these are the send/receive
  // neighbor ranks, send/receive offsets and send indices -- confirm.
  std::vector<integer_t> sranks, rranks, soff, roffs, sind;
  // Presumably the send and receive buffers holding matrix/vector values.
  std::vector<scalar_t> sbuf, rbuf;
  // For each off-diagonal entry, prbuf stores the corresponding index in
  // the receive buffer (presumably rbuf).
  std::vector<integer_t> prbuf;
};
57
58
// Block-row distributed compressed sparse row (CSR) matrix.
//
// Each rank stores the contiguous rows [begin_row(), end_row()) of the
// global matrix. dist() holds the row distribution over the ranks.
// NOTE(review): this listing is incomplete and cannot be compiled as-is.
// - The aliases CSM_t, DenseM_t, DistM_t, Match_t and Equil_t used below
//   are not declared in the visible text (listing lines 73-75, 77-78 are
//   absent).
// - Listing lines 92, 155 and 160 are also missing.
// Restore these from the real header.
71 template<typename scalar_t,typename integer_t>
72 class CSRMatrixMPI : public CompressedSparseMatrix<scalar_t,integer_t> {
76 using real_t = typename RealType<scalar_t>::value_type;
79
80 public:
// Construct from this rank's local rows given in CSR form
// (row_ptr, col_ind, values). NOTE(review): dist is presumably the
// row-distribution array over the ranks of comm -- confirm.
82 CSRMatrixMPI(integer_t local_rows, const integer_t* row_ptr,
83 const integer_t* col_ind, const scalar_t* values,
84 const integer_t* dist, MPIComm comm, bool symm_sparse);
// NOTE(review): by its parameters, presumably sizes storage for
// local_rows rows / local_nnz nonzeros without filling them -- confirm.
85 CSRMatrixMPI(integer_t rows, integer_t local_rows, integer_t local_nnz,
86 const integer_t* dist, MPIComm comm, bool symm_sparse);
// NOTE(review): presumably builds the matrix from separate diagonal
// (d_*) and off-diagonal (o_*) CSR blocks. garray presumably maps
// off-diagonal column indices to global columns -- confirm.
87 CSRMatrixMPI(integer_t lrows, const integer_t* d_ptr,
88 const integer_t* d_ind, const scalar_t* d_val,
89 const integer_t* o_ptr, const integer_t* o_ind,
90 const scalar_t* o_val, const integer_t* garray,
91 MPIComm comm, bool symm_sparse=false);
// NOTE(review): the next line is the tail of another constructor
// declaration. Its opening line (listing line 92) is missing here.
93 bool only_at_root);
94
// Number of nonzeros stored on this rank (lnnz_).
95 integer_t local_nnz() const { return lnnz_; }
// Number of rows stored on this rank (lrows_).
96 integer_t local_rows() const { return lrows_; }
// First global row stored on this rank (brow_).
97 integer_t begin_row() const { return brow_; }
// One past the last global row stored on this rank: brow_ + lrows_.
98 integer_t end_row() const { return brow_ + lrows_; }
99
// The communicator wrapper, and the raw MPI_Comm it wraps.
100 const MPIComm& Comm() const { return comm_; }
101 MPI_Comm comm() const { return comm_.comm(); }
102
// Whole row-distribution array, or its entry p (debug-asserts
// p < dist_.size()).
103 const std::vector<integer_t>& dist() const { return dist_; }
104 const integer_t& dist(std::size_t p) const {
105 assert(p < dist_.size());
106 return dist_[p];
107 }
108
109 real_t norm1() const override;
110
// Sparse matrix-vector product, dense-matrix and raw-pointer variants.
// Although const, these may use the mutable spmv_bufs_ communication
// buffers (presumably set up lazily via setup_spmv_buffers()).
111 void spmv(const DenseM_t& x, DenseM_t& y) const override;
112 void spmv(const scalar_t* x, scalar_t* y) const override;
113
114 void permute(const integer_t* iorder, const integer_t* order) override;
115
// NOTE(review): presumably collect the distributed matrix / its
// sparsity graph into a single sequential object. Confirm which
// rank(s) receive a non-null result.
116 std::unique_ptr<CSRMatrix<scalar_t,integer_t>> gather() const;
117 std::unique_ptr<CSRGraph<integer_t>> gather_graph() const;
118
119
134 Match_t matching(MatchingJob job, bool apply=true) override;
135
136 Equil_t equilibration() const override;
137
138 void equilibrate(const Equil_t&) override;
139
140 void permute_columns(const std::vector<integer_t>& perm) override;
141
142 void symmetrize_sparsity() override;
143
144 int read_matrix_market(const std::string& filename) override;
145
146 real_t max_scaled_residual(const DenseM_t& x, const DenseM_t& b)
147 const override;
148
149 real_t max_scaled_residual(const scalar_t* x, const scalar_t* b)
150 const override;
151
// Returns a new matrix; s is presumably the value inserted for
// structurally missing diagonal entries -- confirm.
152 std::unique_ptr<CSRMatrixMPI<scalar_t,integer_t>>
153 add_missing_diagonal(const scalar_t& s) const;
154
// NOTE(review): the return type of get_sub_graph (listing line 155)
// is missing from this listing.
156 get_sub_graph(const std::vector<integer_t>& perm,
157 const std::vector<std::pair<integer_t,integer_t>>&
158 graph_ranges) const;
159
// Not supported for the distributed matrix. It fails the assert in
// debug builds and otherwise returns an empty CSRGraph.
// NOTE(review): its return-type line (listing line 160) is missing from
// this listing; the body returns CSRGraph<integer_t>.
161 extract_graph(int ordering_level,
162 integer_t lo, integer_t hi) const override {
163 assert(false);
164 return CSRGraph<integer_t>();
165 }
166
167 void print() const override;
168 void print_dense(const std::string& name) const override;
169 void print_matrix_market(const std::string& filename) const override;
170 void check() const;
171
172
// The overrides in this section are empty stubs. The void ones do
// nothing, and the CSRGraph-returning ones return an empty graph.
173#ifndef DOXYGEN_SHOULD_SKIP_THIS
174 // implement outside of this class
175 void extract_separator
176 (integer_t, const std::vector<std::size_t>&,
177 const std::vector<std::size_t>&, DenseM_t&, int) const override {}
178 void extract_separator_2d
179 (integer_t, const std::vector<std::size_t>&,
180 const std::vector<std::size_t>&, DistM_t&) const override {}
181 void extract_front
182 (DenseM_t&, DenseM_t&, DenseM_t&, integer_t,
183 integer_t, const std::vector<integer_t>&, int) const override {}
184 void push_front_elements
185 (integer_t, integer_t, const std::vector<integer_t>&,
186 std::vector<Triplet<scalar_t>>&, std::vector<Triplet<scalar_t>>&,
187 std::vector<Triplet<scalar_t>>&) const override {}
188 void set_front_elements
189 (integer_t, integer_t, const std::vector<integer_t>&,
190 Triplet<scalar_t>*, Triplet<scalar_t>*,
191 Triplet<scalar_t>*) const override {}
192 void count_front_elements
193 (integer_t, integer_t, const std::vector<integer_t>&,
194 std::size_t&, std::size_t&, std::size_t&) const override {}
195
196 void extract_F11_block
197 (scalar_t*, integer_t, integer_t, integer_t,
198 integer_t, integer_t) const override {}
199 void extract_F12_block
200 (scalar_t*, integer_t, integer_t, integer_t, integer_t,
201 integer_t, const integer_t*) const override {}
202 void extract_F21_block
203 (scalar_t*, integer_t, integer_t, integer_t, integer_t,
204 integer_t, const integer_t*) const override {}
205 void front_multiply
206 (integer_t, integer_t, const std::vector<integer_t>&,
207 const DenseM_t&, DenseM_t&, DenseM_t&, int depth) const override {}
208 void front_multiply_2d
209 (integer_t, integer_t, const std::vector<integer_t>&, const DistM_t&,
210 DistM_t&, DistM_t&, int) const override {}
211 void front_multiply_2d
212 (Trans op, integer_t, integer_t, const std::vector<integer_t>&,
213 const DistM_t&, DistM_t&, int) const override {}
214
215 CSRGraph<integer_t> extract_graph_sep_CB
216 (int ordering_level, integer_t lo, integer_t hi,
217 const std::vector<integer_t>& upd) const override {
218 return CSRGraph<integer_t>(); };
219 CSRGraph<integer_t> extract_graph_CB_sep
220 (int ordering_level, integer_t lo, integer_t hi,
221 const std::vector<integer_t>& upd) const override {
222 return CSRGraph<integer_t>(); };
223 CSRGraph<integer_t> extract_graph_CB
224 (int ordering_level, const std::vector<integer_t>& upd) const override {
225 return CSRGraph<integer_t>(); };
226
227 void front_multiply_F11
228 (Trans op, integer_t slo, integer_t shi,
229 const DenseM_t& R, DenseM_t& S, int depth) const override {};
230 void front_multiply_F12
231 (Trans op, integer_t slo, integer_t shi, const std::vector<integer_t>& upd,
232 const DenseM_t& R, DenseM_t& S, int depth) const override {};
233 void front_multiply_F21
234 (Trans op, integer_t slo, integer_t shi, const std::vector<integer_t>& upd,
235 const DenseM_t& R, DenseM_t& S, int depth) const override {};
236#endif //DOXYGEN_SHOULD_SKIP_THIS
237
238 protected:
// NOTE(review): presumably separates each local row into its diagonal
// and off-diagonal parts (see offdiag_start_) -- confirm.
239 void split_diag_offdiag();
// Const so it can be called from the const spmv; it works on the
// mutable spmv_bufs_.
240 void setup_spmv_buffers() const;
241
242 // NOTE(review): stale TODO? comm_ is already an MPIComm.
243 MPIComm comm_;
244
// Row distribution over the ranks; brow_ is dist_[rank].
250 std::vector<integer_t> dist_;
251
// NOTE(review): presumably, per local row, the position where the
// off-diagonal entries start -- confirm in split_diag_offdiag().
256 std::vector<integer_t> offdiag_start_;
257
258 integer_t brow_; // = dist_[rank]; local rows are [brow_, brow_ + lrows_)
259 integer_t lrows_; // = end_row() - brow_, number of locally stored rows
260 integer_t lnnz_; // = ptr_[local_rows]
261
// Communication buffers for spmv. Mutable because the const spmv and
// setup_spmv_buffers() need to modify them.
262 mutable SPMVBuffers<scalar_t,integer_t> spmv_bufs_;
263
264
// NOTE(review): by name, lDr is presumably the local row scaling and gDc
// the global column scaling -- confirm.
268 void scale(const std::vector<scalar_t>& lDr,
269 const std::vector<scalar_t>& gDc) override;
270 void scale_real(const std::vector<real_t>& lDr,
271 const std::vector<real_t>& gDc) override;
272
273 void sort_rows();
274
// Storage inherited from CompressedSparseMatrix (via CSM_t).
275 using CSM_t::n_;
276 using CSM_t::nnz_;
277 using CSM_t::ptr_;
278 using CSM_t::ind_;
279 using CSM_t::val_;
280 using CSM_t::symm_sparse_;
281 };
282
// Declares a function template returning CSRMatrixMPI<cast_t,integer_t>.
// It presumably converts a CSRMatrixMPI<scalar_t,integer_t> to scalar
// type cast_t, keeping integer_t.
// NOTE(review): the line with the function name and parameter list
// (listing line 296) is missing from this listing, so the exact
// signature cannot be confirmed from here.
294 template<typename scalar_t, typename integer_t, typename cast_t>
295 CSRMatrixMPI<cast_t,integer_t>
297
298} // end namespace strumpack
299
300#endif // STRUMPACK_CSRMATRIX_MPI_HPP
Contains the compressed sparse row matrix storage class.
Contains some simple C++ MPI wrapper utilities.
Definition CompressedSparseMatrix.hpp:49
Block-row distributed compressed sparse row storage.
Definition CSRMatrixMPI.hpp:72
void permute(const integer_t *iorder, const integer_t *order) override
void spmv(const scalar_t *x, scalar_t *y) const override
void spmv(const DenseM_t &x, DenseM_t &y) const override
Match_t matching(MatchingJob job, bool apply=true) override
Class for storing a compressed sparse row matrix (single node).
Definition CSRMatrix.hpp:55
Abstract base class for compressed sparse matrix storage.
Definition CompressedSparseMatrix.hpp:149
bool symm_sparse() const
Definition CompressedSparseMatrix.hpp:284
This class represents a matrix, stored in column-major format, to allow direct use of BLAS/LAPACK routines.
Definition DenseMatrix.hpp:139
2D block-cyclically distributed matrix, as used by ScaLAPACK.
Definition DistributedMatrix.hpp:84
Definition CompressedSparseMatrix.hpp:97
Wrapper class around an MPI_Comm object.
Definition MPIWrapper.hpp:173
MPI_Comm comm() const
Definition MPIWrapper.hpp:240
Definition CompressedSparseMatrix.hpp:56
Definition CSRMatrixMPI.hpp:48
Definition StrumpackOptions.hpp:44
CSRMatrix< cast_t, integer_t > cast_matrix(const CSRMatrix< scalar_t, integer_t > &mat)
MatchingJob
Definition StrumpackOptions.hpp:120
Trans
Definition DenseMatrix.hpp:51