28 #ifndef DISTRIBUTED_MATRIX_HPP
29 #define DISTRIBUTED_MATRIX_HPP
36 #include "misc/RandomWrapper.hpp"
/**
 * Convert a 1-based local index of a block-cyclically distributed
 * dimension into the matching 1-based global index.
 *
 * \param INDXLOC  local index (1-based)
 * \param NB       block size of the distribution
 * \param IPROC    coordinate of the process owning the local index
 * \param ISRCPROC coordinate of the process holding the first block
 * \param NPROCS   number of processes along this dimension
 * \return the global index (1-based)
 */
inline int indxl2g(int INDXLOC, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int loc0 = INDXLOC - 1;           // zero-based local index
  const int local_block = loc0 / NB;      // local block holding the index
  const int in_block = loc0 % NB;         // position inside that block
  // distance (in processes) from the source process, wrapped to [0,NPROCS)
  const int proc_shift = (NPROCS + IPROC - ISRCPROC) % NPROCS;
  return NPROCS * NB * local_block + in_block + proc_shift * NB + 1;
}
/**
 * Convert a 1-based global index of a block-cyclically distributed
 * dimension into the 1-based local index on the process that owns it.
 *
 * IPROC and ISRCPROC are accepted but do not enter the computation.
 *
 * \param INDXGLOB global index (1-based)
 * \param NB       block size of the distribution
 * \param IPROC    not used by the computation
 * \param ISRCPROC not used by the computation
 * \param NPROCS   number of processes along this dimension
 * \return the local index (1-based)
 */
inline int indxg2l(int INDXGLOB, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int glob0 = INDXGLOB - 1;              // zero-based global index
  const int full_cycles = glob0 / (NB * NPROCS); // complete rounds over all processes
  const int in_block = glob0 % NB;             // position inside the block
  return NB * full_cycles + in_block + 1;
}
/**
 * Compute the coordinate of the process that owns a given 1-based
 * global index of a block-cyclically distributed dimension.
 *
 * IPROC is accepted but does not enter the computation.
 *
 * \param INDXGLOB global index (1-based)
 * \param NB       block size of the distribution
 * \param IPROC    not used by the computation
 * \param ISRCPROC coordinate of the process holding the first block
 * \param NPROCS   number of processes along this dimension
 * \return process coordinate in [0, NPROCS)
 */
inline int indxg2p(int INDXGLOB, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int global_block = (INDXGLOB - 1) / NB;  // zero-based block number
  const int owner = (ISRCPROC + global_block) % NPROCS;
  return owner;
}
51 template<
typename scalar_t>
class DistributedMatrix {
52 using real_t =
typename RealType<scalar_t>::value_type;
55 const BLACSGrid* grid_ =
nullptr;
56 scalar_t* data_ =
nullptr;
63 DistributedMatrix(
const BLACSGrid* g,
const DenseMatrix<scalar_t>& m);
64 DistributedMatrix(
const BLACSGrid* g, DenseMatrix<scalar_t>&& m);
65 DistributedMatrix(
const BLACSGrid* g, DenseMatrixWrapper<scalar_t>&& m);
66 DistributedMatrix(
const BLACSGrid* g,
int M,
int N,
67 const DistributedMatrix<scalar_t>& m,
69 DistributedMatrix(
const BLACSGrid* g,
int M,
int N);
70 DistributedMatrix(
const BLACSGrid* g,
int M,
int N,
int MB,
int NB);
71 DistributedMatrix(
const BLACSGrid* g,
int desc[9]);
73 DistributedMatrix(
const DistributedMatrix<scalar_t>& m);
74 DistributedMatrix(DistributedMatrix<scalar_t>&& m);
75 virtual ~DistributedMatrix();
77 DistributedMatrix<scalar_t>&
78 operator=(
const DistributedMatrix<scalar_t>& m);
79 DistributedMatrix<scalar_t>&
80 operator=(DistributedMatrix<scalar_t>&& m);
83 inline const int* desc()
const {
return desc_; }
84 inline int* desc() {
return desc_; }
85 inline bool active()
const {
return grid() && grid()->
active(); }
87 inline const BLACSGrid* grid()
const {
return grid_; }
88 inline const MPIComm& Comm()
const {
return grid()->
Comm(); }
89 inline MPI_Comm comm()
const {
return Comm().
comm(); }
91 inline int ctxt()
const {
return grid() ? grid()->
ctxt() : -1; }
92 inline int ctxt_all()
const {
return grid() ? grid()->
ctxt_all() : -1; }
94 virtual int rows()
const {
return desc_[2]; }
95 virtual int cols()
const {
return desc_[3]; }
96 inline int lrows()
const {
return lrows_; }
97 inline int lcols()
const {
return lcols_; }
98 inline int ld()
const {
return lrows_; }
99 inline int MB()
const {
return desc_[4]; }
100 inline int NB()
const {
return desc_[5]; }
101 inline int rowblocks()
const {
return std::ceil(
float(lrows()) / MB()); }
102 inline int colblocks()
const {
return std::ceil(
float(lcols()) / NB()); }
104 virtual int I()
const {
return 1; }
105 virtual int J()
const {
return 1; }
106 virtual void lranges(
int& rlo,
int& rhi,
int& clo,
int& chi)
const;
108 inline const scalar_t* data()
const {
return data_; }
109 inline scalar_t* data() {
return data_; }
110 inline const scalar_t& operator()(
int r,
int c)
const
111 {
return data_[r+ld()*c]; }
112 inline scalar_t& operator()(
int r,
int c) {
return data_[r+ld()*c]; }
114 inline int prow()
const { assert(grid());
return grid()->
prow(); }
115 inline int pcol()
const { assert(grid());
return grid()->
pcol(); }
116 inline int nprows()
const { assert(grid());
return grid()->
nprows(); }
117 inline int npcols()
const { assert(grid());
return grid()->
npcols(); }
119 inline bool is_master()
const {
return grid() && prow() == 0 && pcol() == 0; }
120 inline int rowl2g(
int row)
const { assert(grid());
121 return indxl2g(row+1, MB(), prow(), 0, nprows()) - I(); }
122 inline int coll2g(
int col)
const { assert(grid());
123 return indxl2g(col+1, NB(), pcol(), 0, npcols()) - J(); }
124 inline int rowg2l(
int row)
const { assert(grid());
125 return indxg2l(row+I(), MB(), prow(), 0, nprows()) - 1; }
126 inline int colg2l(
int col)
const { assert(grid());
127 return indxg2l(col+J(), NB(), pcol(), 0, npcols()) - 1; }
128 inline int rowg2p(
int row)
const { assert(grid());
129 return indxg2p(row+I(), MB(), prow(), 0, nprows()); }
130 inline int colg2p(
int col)
const { assert(grid());
131 return indxg2p(col+J(), NB(), pcol(), 0, npcols()); }
132 inline int rank(
int r,
int c)
const {
133 return rowg2p(r) + colg2p(c) * nprows(); }
134 inline bool is_local(
int r,
int c)
const { assert(grid());
135 return rowg2p(r) == prow() && colg2p(c) == pcol();
138 bool fixed()
const {
return MB()==default_MB && NB()==default_NB; }
139 int rowl2g_fixed(
int row)
const {
140 assert(grid() && fixed());
141 return indxl2g(row+1, default_MB, prow(), 0, nprows()) - I(); }
142 int coll2g_fixed(
int col)
const {
143 assert(grid() && fixed());
144 return indxl2g(col+1, default_NB, pcol(), 0, npcols()) - J(); }
145 int rowg2l_fixed(
int row)
const {
146 assert(grid() && fixed());
147 return indxg2l(row+I(), default_MB, prow(), 0, nprows()) - 1; }
148 int colg2l_fixed(
int col)
const {
149 assert(grid() && fixed());
150 return indxg2l(col+J(), default_NB, pcol(), 0, npcols()) - 1; }
151 int rowg2p_fixed(
int row)
const {
152 assert(grid() && fixed());
153 return indxg2p(row+I(), default_MB, prow(), 0, nprows()); }
154 int colg2p_fixed(
int col)
const {
155 assert(grid() && fixed());
156 return indxg2p(col+J(), default_NB, pcol(), 0, npcols()); }
157 int rank_fixed(
int r,
int c)
const {
158 assert(grid() && fixed());
return rowg2p_fixed(r) + colg2p_fixed(c) * nprows(); }
159 bool is_local_fixed(
int r,
int c)
const {
160 assert(grid() && fixed());
161 return rowg2p_fixed(r) == prow() && colg2p_fixed(c) == pcol(); }
164 const scalar_t& global(
int r,
int c)
const
165 { assert(is_local(r, c));
return operator()(rowg2l(r),colg2l(c)); }
166 scalar_t& global(
int r,
int c)
167 { assert(is_local(r, c));
return operator()(rowg2l(r),colg2l(c)); }
168 scalar_t& global_fixed(
int r,
int c) {
169 assert(is_local(r, c)); assert(fixed());
170 return operator()(rowg2l_fixed(r),colg2l_fixed(c)); }
171 void global(
int r,
int c, scalar_t v) {
172 if (active() && is_local(r, c)) operator()(rowg2l(r),colg2l(c)) = v; }
173 scalar_t all_global(
int r,
int c)
const;
175 void print()
const { print(
"A"); }
176 void print(std::string name,
int precision=15)
const;
177 void print_to_file(std::string name, std::string filename,
179 void print_to_files(std::string name,
int precision=16)
const;
181 void random(random::RandomGeneratorBase<
typename RealType<scalar_t>::
184 void fill(scalar_t a);
186 void shift(scalar_t sigma);
188 virtual void resize(std::size_t m, std::size_t n);
189 virtual void hconcat(
const DistributedMatrix<scalar_t>& b);
190 DistributedMatrix<scalar_t> transpose()
const;
192 void laswp(
const std::vector<int>& P,
bool fwd);
194 DistributedMatrix<scalar_t>
195 extract_rows(
const std::vector<std::size_t>& Ir)
const;
196 DistributedMatrix<scalar_t>
197 extract_cols(
const std::vector<std::size_t>& Ic)
const;
199 DistributedMatrix<scalar_t>
200 extract(
const std::vector<std::size_t>& I,
201 const std::vector<std::size_t>& J)
const;
202 DistributedMatrix<scalar_t>& add(
const DistributedMatrix<scalar_t>& B);
203 DistributedMatrix<scalar_t>&
204 scaled_add(scalar_t alpha,
const DistributedMatrix<scalar_t>& B);
207 real_t normF()
const;
208 real_t norm1()
const;
209 real_t normI()
const;
211 virtual std::size_t memory()
const
212 {
return sizeof(scalar_t)*std::size_t(lrows())*std::size_t(lcols()); }
213 virtual std::size_t total_memory()
const
214 {
return sizeof(scalar_t)*std::size_t(rows())*std::size_t(cols()); }
215 virtual std::size_t nonzeros()
const
216 {
return std::size_t(lrows())*std::size_t(lcols()); }
217 virtual std::size_t total_nonzeros()
const
218 {
return std::size_t(rows())*std::size_t(cols()); }
220 void scatter(
const DenseMatrix<scalar_t>& a);
221 DenseMatrix<scalar_t> gather()
const;
222 DenseMatrix<scalar_t> all_gather()
const;
224 DenseMatrix<scalar_t> dense_and_clear();
225 DenseMatrix<scalar_t> dense()
const;
226 DenseMatrixWrapper<scalar_t> dense_wrapper();
228 std::vector<int> LU();
229 int LU(std::vector<int>&);
231 DistributedMatrix<scalar_t>
232 solve(
const DistributedMatrix<scalar_t>& b,
233 const std::vector<int>& piv)
const;
235 void LQ(DistributedMatrix<scalar_t>& L,
236 DistributedMatrix<scalar_t>& Q)
const;
238 void orthogonalize(scalar_t& r_max, scalar_t& r_min);
240 void ID_column(DistributedMatrix<scalar_t>& X, std::vector<int>& piv,
241 std::vector<std::size_t>& ind,
242 real_t rel_tol, real_t abs_tol,
int max_rank);
243 void ID_row(DistributedMatrix<scalar_t>& X, std::vector<int>& piv,
244 std::vector<std::size_t>& ind, real_t rel_tol, real_t abs_tol,
245 int max_rank,
const BLACSGrid* grid_T);
247 static const int default_MB = STRUMPACK_PBLAS_BLOCKSIZE;
248 static const int default_NB = STRUMPACK_PBLAS_BLOCKSIZE;
255 template<
typename scalar_t>
void copy
256 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& a,
257 std::size_t ia, std::size_t ja, DenseMatrix<scalar_t>& b,
258 int dest,
int context_all);
260 template<
typename scalar_t>
void copy
261 (std::size_t m, std::size_t n,
const DenseMatrix<scalar_t>& a,
int src,
262 DistributedMatrix<scalar_t>& b, std::size_t ib, std::size_t jb,
266 template<
typename scalar_t>
void copy
267 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& a,
268 std::size_t ia, std::size_t ja, DistributedMatrix<scalar_t>& b,
269 std::size_t ib, std::size_t jb,
int context_all);
276 template<
typename scalar_t>
279 int _rows, _cols, _i, _j;
282 _rows(0), _cols(0), _i(0), _j(0) {}
289 std::size_t i, std::size_t j);
293 int MB,
int NB, scalar_t* A);
304 int rows()
const override {
return _rows; }
305 int cols()
const override {
return _cols; }
306 int I()
const override {
return _i+1; }
307 int J()
const override {
return _j+1; }
308 void lranges(
int& rlo,
int& rhi,
int& clo,
int& chi)
const override;
/**
 * Override of resize(m, n) whose body only evaluates assert(1); the
 * arguments m and n are not used and no member is modified here.
 *
 * NOTE(review): assert(1) can never fail, so this is a silent no-op
 * in every build. Presumably assert(0)/assert(false) was intended, to
 * reject resizing of this object -- confirm against callers before
 * changing it.
 */
310 void resize(std::size_t m, std::size_t n)
override { assert(1); }
314 std::size_t memory()
const override {
return 0; }
315 std::size_t total_memory()
const override {
return 0; }
316 std::size_t nonzeros()
const override {
return 0; }
317 std::size_t total_nonzeros()
const override {
return 0; }
328 template<
typename scalar_t>
long long int
330 if (!a.is_master())
return 0;
331 return (is_complex<scalar_t>() ? 4:1) *
332 blas::getrf_flops(a.rows(), a.cols());
335 template<
typename scalar_t>
long long int
336 solve_flops(
const DistributedMatrix<scalar_t>& b) {
337 if (!b.is_master())
return 0;
338 return (is_complex<scalar_t>() ? 4:1) *
339 blas::getrs_flops(b.rows(), b.cols());
342 template<
typename scalar_t>
long long int
343 LQ_flops(
const DistributedMatrix<scalar_t>& a) {
344 if (!a.is_master())
return 0;
345 auto minrc = std::min(a.rows(), a.cols());
346 return (is_complex<scalar_t>() ? 4:1) *
347 (blas::gelqf_flops(a.rows(), a.cols()) +
348 blas::xxglq_flops(a.cols(), a.cols(), minrc));
351 template<
typename scalar_t>
long long int
352 ID_row_flops(
const DistributedMatrix<scalar_t>& a,
int rank) {
353 if (!a.is_master())
return 0;
354 return (is_complex<scalar_t>() ? 4:1) *
355 (blas::geqp3_flops(a.cols(), a.rows())
359 template<
typename scalar_t>
long long int
360 trsm_flops(
Side s, scalar_t alpha,
const DistributedMatrix<scalar_t>& a,
361 const DistributedMatrix<scalar_t>& b) {
362 if (!a.is_master())
return 0;
363 return (is_complex<scalar_t>() ? 4:1) *
367 template<
typename scalar_t>
long long int
369 const DistributedMatrix<scalar_t>& a,
370 const DistributedMatrix<scalar_t>& b, scalar_t beta) {
371 if (!a.is_master())
return 0;
372 return (is_complex<scalar_t>() ? 4:1) *
374 ((ta==
Trans::N) ? a.rows() : a.cols(),
375 (tb==
Trans::N) ? b.cols() : b.rows(),
376 (ta==
Trans::N) ? a.cols() : a.rows(), alpha, beta);
379 template<
typename scalar_t>
long long int
380 gemv_flops(
Trans ta,
const DistributedMatrix<scalar_t>& a,
381 scalar_t alpha, scalar_t beta) {
382 if (!a.is_master())
return 0;
383 auto m = (ta==
Trans::N) ? a.rows() : a.cols();
384 auto n = (ta==
Trans::N) ? a.cols() : a.rows();
385 return (is_complex<scalar_t>() ? 4:1) *
386 ((alpha != scalar_t(0.)) * m * (n * 2 - 1) +
387 (alpha != scalar_t(1.) && alpha != scalar_t(0.)) * m +
388 (beta != scalar_t(0.) && beta != scalar_t(1.)) * m +
389 (alpha != scalar_t(0.) && beta != scalar_t(0.)) * m);
392 template<
typename scalar_t>
long long int
394 if (!a.is_master())
return 0;
395 auto minrc = std::min(a.rows(), a.cols());
396 return (is_complex<scalar_t>() ? 4:1) *
397 (blas::geqrf_flops(a.rows(), minrc) +
398 blas::xxgqr_flops(a.rows(), minrc, minrc));
402 template<
typename scalar_t>
403 std::unique_ptr<const DistributedMatrixWrapper<scalar_t>>
404 ConstDistributedMatrixWrapperPtr
405 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& D,
406 std::size_t i, std::size_t j) {
407 return std::unique_ptr<const DistributedMatrixWrapper<scalar_t>>
408 (
new DistributedMatrixWrapper<scalar_t>
409 (m, n,
const_cast<DistributedMatrix<scalar_t>&
>(D), i, j));
413 template<
typename scalar_t>
void gemm
414 (
Trans ta,
Trans tb, scalar_t alpha,
const DistributedMatrix<scalar_t>& A,
415 const DistributedMatrix<scalar_t>& B, scalar_t beta,
416 DistributedMatrix<scalar_t>& C);
418 template<
typename scalar_t>
void trsm
420 const DistributedMatrix<scalar_t>& A, DistributedMatrix<scalar_t>& B);
422 template<
typename scalar_t>
void trsv
423 (
UpLo ul,
Trans ta,
Diag d,
const DistributedMatrix<scalar_t>& A,
424 DistributedMatrix<scalar_t>& B);
426 template<
typename scalar_t>
void gemv
427 (
Trans ta, scalar_t alpha,
const DistributedMatrix<scalar_t>& A,
428 const DistributedMatrix<scalar_t>& X, scalar_t beta,
429 DistributedMatrix<scalar_t>& Y);
431 template<
typename scalar_t> DistributedMatrix<scalar_t>
vconcat
432 (
int cols,
int arows,
int brows,
const DistributedMatrix<scalar_t>& a,
433 const DistributedMatrix<scalar_t>& b,
const BLACSGrid* gnew,
int cxt_all);
435 template<
typename scalar_t>
void subgrid_copy_to_buffers
436 (
const DistributedMatrix<scalar_t>& a,
const DistributedMatrix<scalar_t>& b,
437 int p0,
int npr,
int npc, std::vector<std::vector<scalar_t>>& sbuf);
439 template<
typename scalar_t>
void subproc_copy_to_buffers
440 (
const DenseMatrix<scalar_t>& a,
const DistributedMatrix<scalar_t>& b,
441 int p0,
int npr,
int npc, std::vector<std::vector<scalar_t>>& sbuf);
443 template<
typename scalar_t>
void subgrid_add_from_buffers
444 (
const BLACSGrid* subg,
int master, DistributedMatrix<scalar_t>& b,
445 std::vector<scalar_t*>& pbuf);
449 #endif // DISTRIBUTED_MATRIX_HPP