28 #ifndef DISTRIBUTED_MATRIX_HPP
29 #define DISTRIBUTED_MATRIX_HPP
36 #include "misc/RandomWrapper.hpp"
/**
 * Local-to-global index map for a 1-D block-cyclic distribution
 * (same formula as ScaLAPACK's INDXL2G). All indices are 1-based.
 *
 * @param INDXLOC   local index on process IPROC
 * @param NB        block size
 * @param IPROC     coordinate of the process owning INDXLOC
 * @param ISRCPROC  coordinate of the process owning the first block
 * @param NPROCS    number of processes in this dimension
 * @return the corresponding 1-based global index
 */
inline int indxl2g(int INDXLOC, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int zero_based = INDXLOC - 1;
  const int local_block = zero_based / NB;  // which local block
  const int in_block = zero_based % NB;     // offset inside that block
  // How many process positions IPROC sits after the source process.
  const int proc_shift = (NPROCS + IPROC - ISRCPROC) % NPROCS;
  return local_block * NB * NPROCS + proc_shift * NB + in_block + 1;
}
/**
 * Global-to-local index map for a 1-D block-cyclic distribution
 * (same formula as ScaLAPACK's INDXG2L). All indices are 1-based.
 * IPROC and ISRCPROC are accepted for signature compatibility but the
 * result does not depend on them.
 *
 * @return the 1-based local index of global index INDXGLOB
 */
inline int indxg2l(int INDXGLOB, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int g = INDXGLOB - 1;
  const int cycle = NB * NPROCS;  // entries covered by one round of blocks
  return (g / cycle) * NB + g % NB + 1;
}
/**
 * Process coordinate owning a 1-based global index in a 1-D
 * block-cyclic distribution (same formula as ScaLAPACK's INDXG2P).
 * IPROC is accepted for signature compatibility but unused.
 *
 * @return 0-based process coordinate in [0, NPROCS)
 */
inline int indxg2p(int INDXGLOB, int NB, int IPROC, int ISRCPROC, int NPROCS)
{
  const int global_block = (INDXGLOB - 1) / NB;
  return (global_block + ISRCPROC) % NPROCS;
}
51 template<
typename scalar_t>
class DistributedMatrix {
52 using real_t =
typename RealType<scalar_t>::value_type;
55 const BLACSGrid* grid_ =
nullptr;
56 scalar_t* data_ =
nullptr;
63 DistributedMatrix(
const BLACSGrid* g,
const DenseMatrix<scalar_t>& m);
64 DistributedMatrix(
const BLACSGrid* g, DenseMatrix<scalar_t>&& m);
65 DistributedMatrix(
const BLACSGrid* g, DenseMatrixWrapper<scalar_t>&& m);
66 DistributedMatrix(
const BLACSGrid* g,
int M,
int N,
67 const DistributedMatrix<scalar_t>& m,
69 DistributedMatrix(
const BLACSGrid* g,
int M,
int N);
70 DistributedMatrix(
const BLACSGrid* g,
int M,
int N,
int MB,
int NB);
71 DistributedMatrix(
const BLACSGrid* g,
int desc[9]);
73 DistributedMatrix(
const DistributedMatrix<scalar_t>& m);
74 DistributedMatrix(DistributedMatrix<scalar_t>&& m);
75 virtual ~DistributedMatrix();
77 DistributedMatrix<scalar_t>&
78 operator=(
const DistributedMatrix<scalar_t>& m);
79 DistributedMatrix<scalar_t>&
80 operator=(DistributedMatrix<scalar_t>&& m);
83 inline const int* desc()
const {
return desc_; }
84 inline int* desc() {
return desc_; }
85 inline bool active()
const {
return grid() && grid()->
active(); }
87 inline const BLACSGrid* grid()
const {
return grid_; }
88 inline const MPIComm& Comm()
const {
return grid()->
Comm(); }
89 inline MPI_Comm comm()
const {
return Comm().
comm(); }
91 inline int ctxt()
const {
return grid() ? grid()->
ctxt() : -1; }
92 inline int ctxt_all()
const {
return grid() ? grid()->
ctxt_all() : -1; }
94 virtual int rows()
const {
return desc_[2]; }
95 virtual int cols()
const {
return desc_[3]; }
96 inline int lrows()
const {
return lrows_; }
97 inline int lcols()
const {
return lcols_; }
98 inline int ld()
const {
return lrows_; }
99 inline int MB()
const {
return desc_[4]; }
100 inline int NB()
const {
return desc_[5]; }
101 inline int rowblocks()
const {
return std::ceil(
float(lrows()) / MB()); }
102 inline int colblocks()
const {
return std::ceil(
float(lcols()) / NB()); }
104 virtual int I()
const {
return 1; }
105 virtual int J()
const {
return 1; }
106 virtual void lranges(
int& rlo,
int& rhi,
int& clo,
int& chi)
const;
108 inline const scalar_t* data()
const {
return data_; }
109 inline scalar_t* data() {
return data_; }
110 inline const scalar_t& operator()(
int r,
int c)
const
111 {
return data_[r+ld()*c]; }
112 inline scalar_t& operator()(
int r,
int c) {
return data_[r+ld()*c]; }
114 inline int prow()
const { assert(grid());
return grid()->
prow(); }
115 inline int pcol()
const { assert(grid());
return grid()->
pcol(); }
116 inline int nprows()
const { assert(grid());
return grid()->
nprows(); }
117 inline int npcols()
const { assert(grid());
return grid()->
npcols(); }
119 inline bool is_master()
const {
return grid() && prow() == 0 && pcol() == 0; }
120 inline int rowl2g(
int row)
const { assert(grid());
121 return indxl2g(row+1, MB(), prow(), 0, nprows()) - I(); }
122 inline int coll2g(
int col)
const { assert(grid());
123 return indxl2g(col+1, NB(), pcol(), 0, npcols()) - J(); }
124 inline int rowg2l(
int row)
const { assert(grid());
125 return indxg2l(row+I(), MB(), prow(), 0, nprows()) - 1; }
126 inline int colg2l(
int col)
const { assert(grid());
127 return indxg2l(col+J(), NB(), pcol(), 0, npcols()) - 1; }
128 inline int rowg2p(
int row)
const { assert(grid());
129 return indxg2p(row+I(), MB(), prow(), 0, nprows()); }
130 inline int colg2p(
int col)
const { assert(grid());
131 return indxg2p(col+J(), NB(), pcol(), 0, npcols()); }
132 inline int rank(
int r,
int c)
const {
133 return rowg2p(r) + colg2p(c) * nprows(); }
134 inline bool is_local(
int r,
int c)
const { assert(grid());
135 return rowg2p(r) == prow() && colg2p(c) == pcol();
138 inline bool fixed()
const {
return MB()==default_MB && NB()==default_NB; }
139 inline int rowl2g_fixed(
int row)
const {
140 assert(grid() && fixed());
141 return indxl2g(row+1, default_MB, prow(), 0, nprows()) - I(); }
142 inline int coll2g_fixed(
int col)
const {
143 assert(grid() && fixed());
144 return indxl2g(col+1, default_NB, pcol(), 0, npcols()) - J(); }
145 inline int rowg2l_fixed(
int row)
const {
146 assert(grid() && fixed());
147 return indxg2l(row+I(), default_MB, prow(), 0, nprows()) - 1; }
148 inline int colg2l_fixed(
int col)
const {
149 assert(grid() && fixed());
150 return indxg2l(col+J(), default_NB, pcol(), 0, npcols()) - 1; }
151 inline int rowg2p_fixed(
int row)
const {
152 assert(grid() && fixed());
153 return indxg2p(row+I(), default_MB, prow(), 0, nprows()); }
154 inline int colg2p_fixed(
int col)
const {
155 assert(grid() && fixed());
156 return indxg2p(col+J(), default_NB, pcol(), 0, npcols()); }
157 inline int rank_fixed(
int r,
int c)
const {
158 assert(grid() && fixed());
return rowg2p_fixed(r) + colg2p_fixed(c) * nprows(); }
159 inline bool is_local_fixed(
int r,
int c)
const {
160 assert(grid() && fixed());
161 return rowg2p_fixed(r) == prow() && colg2p_fixed(c) == pcol(); }
164 inline const scalar_t& global(
int r,
int c)
const
165 { assert(is_local(r, c));
return operator()(rowg2l(r),colg2l(c)); }
166 inline scalar_t& global(
int r,
int c)
167 { assert(is_local(r, c));
return operator()(rowg2l(r),colg2l(c)); }
168 inline scalar_t& global_fixed(
int r,
int c) {
169 assert(is_local(r, c)); assert(fixed());
170 return operator()(rowg2l_fixed(r),colg2l_fixed(c)); }
171 inline void global(
int r,
int c, scalar_t v) {
172 if (active() && is_local(r, c)) operator()(rowg2l(r),colg2l(c)) = v; }
173 scalar_t all_global(
int r,
int c)
const;
175 void print()
const { print(
"A"); }
176 void print(std::string name,
int precision=15)
const;
178 (std::string name, std::string filename,
int width=8)
const;
180 (std::string name,
int precision=16)
const;
183 (random::RandomGeneratorBase<
typename RealType<scalar_t>::
186 void fill(scalar_t a);
188 void shift(scalar_t sigma);
190 virtual void resize(std::size_t m, std::size_t n);
191 virtual void hconcat(
const DistributedMatrix<scalar_t>& b);
192 DistributedMatrix<scalar_t> transpose()
const;
194 void laswp(
const std::vector<int>& P,
bool fwd);
196 DistributedMatrix<scalar_t>
197 extract_rows(
const std::vector<std::size_t>& Ir)
const;
198 DistributedMatrix<scalar_t>
199 extract_cols(
const std::vector<std::size_t>& Ic)
const;
200 DistributedMatrix<scalar_t> extract
201 (
const std::vector<std::size_t>& I,
202 const std::vector<std::size_t>& J)
const;
203 DistributedMatrix<scalar_t>& add(
const DistributedMatrix<scalar_t>& B);
204 DistributedMatrix<scalar_t>& scaled_add
205 (scalar_t alpha,
const DistributedMatrix<scalar_t>& B);
206 typename RealType<scalar_t>::value_type norm()
const;
207 typename RealType<scalar_t>::value_type normF()
const;
208 typename RealType<scalar_t>::value_type norm1()
const;
209 typename RealType<scalar_t>::value_type normI()
const;
210 virtual std::size_t memory()
const
211 {
return sizeof(scalar_t)*std::size_t(lrows())*std::size_t(lcols()); }
212 virtual std::size_t total_memory()
const
213 {
return sizeof(scalar_t)*std::size_t(rows())*std::size_t(cols()); }
214 virtual std::size_t nonzeros()
const
215 {
return std::size_t(lrows())*std::size_t(lcols()); }
216 virtual std::size_t total_nonzeros()
const
217 {
return std::size_t(rows())*std::size_t(cols()); }
219 void scatter(
const DenseMatrix<scalar_t>& a);
220 DenseMatrix<scalar_t> gather()
const;
221 DenseMatrix<scalar_t> all_gather()
const;
223 DenseMatrix<scalar_t> dense_and_clear();
224 DenseMatrix<scalar_t> dense()
const;
225 DenseMatrixWrapper<scalar_t> dense_wrapper();
227 std::vector<int> LU();
228 DistributedMatrix<scalar_t> solve
229 (
const DistributedMatrix<scalar_t>& b,
const std::vector<int>& piv)
const;
231 (DistributedMatrix<scalar_t>& L, DistributedMatrix<scalar_t>& Q)
const;
232 void orthogonalize(scalar_t& r_max, scalar_t& r_min);
234 (DistributedMatrix<scalar_t>& X, std::vector<int>& piv,
235 std::vector<std::size_t>& ind, real_t rel_tol, real_t abs_tol,
int max_rank);
237 (DistributedMatrix<scalar_t>& X, std::vector<int>& piv,
238 std::vector<std::size_t>& ind, real_t rel_tol, real_t abs_tol,
239 int max_rank,
const BLACSGrid* grid_T);
241 static const int default_MB = STRUMPACK_PBLAS_BLOCKSIZE;
242 static const int default_NB = STRUMPACK_PBLAS_BLOCKSIZE;
249 template<
typename scalar_t>
void copy
250 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& a,
251 std::size_t ia, std::size_t ja, DenseMatrix<scalar_t>& b,
252 int dest,
int context_all);
254 template<
typename scalar_t>
void copy
255 (std::size_t m, std::size_t n,
const DenseMatrix<scalar_t>& a,
int src,
256 DistributedMatrix<scalar_t>& b, std::size_t ib, std::size_t jb,
260 template<
typename scalar_t>
void copy
261 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& a,
262 std::size_t ia, std::size_t ja, DistributedMatrix<scalar_t>& b,
263 std::size_t ib, std::size_t jb,
int context_all);
270 template<
typename scalar_t>
273 int _rows, _cols, _i, _j;
276 _rows(0), _cols(0), _i(0), _j(0) {}
283 std::size_t i, std::size_t j);
287 int MB,
int NB, scalar_t* A);
298 int rows()
const override {
return _rows; }
299 int cols()
const override {
return _cols; }
300 int I()
const override {
return _i+1; }
301 int J()
const override {
return _j+1; }
302 void lranges(
int& rlo,
int& rhi,
int& clo,
int& chi)
const override;
304 void resize(std::size_t m, std::size_t n)
override { assert(1); }
308 std::size_t memory()
const override {
return 0; }
309 std::size_t total_memory()
const override {
return 0; }
310 std::size_t nonzeros()
const override {
return 0; }
311 std::size_t total_nonzeros()
const override {
return 0; }
322 template<
typename scalar_t>
long long int
324 if (!a.is_master())
return 0;
325 return (is_complex<scalar_t>() ? 4:1) *
326 blas::getrf_flops(a.rows(), a.cols());
329 template<
typename scalar_t>
long long int
330 solve_flops(
const DistributedMatrix<scalar_t>& b) {
331 if (!b.is_master())
return 0;
332 return (is_complex<scalar_t>() ? 4:1) *
333 blas::getrs_flops(b.rows(), b.cols());
336 template<
typename scalar_t>
long long int
337 LQ_flops(
const DistributedMatrix<scalar_t>& a) {
338 if (!a.is_master())
return 0;
339 auto minrc = std::min(a.rows(), a.cols());
340 return (is_complex<scalar_t>() ? 4:1) *
341 (blas::gelqf_flops(a.rows(), a.cols()) +
342 blas::xxglq_flops(a.cols(), a.cols(), minrc));
345 template<
typename scalar_t>
long long int
346 ID_row_flops(
const DistributedMatrix<scalar_t>& a,
int rank) {
347 if (!a.is_master())
return 0;
348 return (is_complex<scalar_t>() ? 4:1) *
349 (blas::geqp3_flops(a.cols(), a.rows())
353 template<
typename scalar_t>
long long int
354 trsm_flops(
Side s, scalar_t alpha,
const DistributedMatrix<scalar_t>& a,
355 const DistributedMatrix<scalar_t>& b) {
356 if (!a.is_master())
return 0;
357 return (is_complex<scalar_t>() ? 4:1) *
361 template<
typename scalar_t>
long long int
363 const DistributedMatrix<scalar_t>& a,
364 const DistributedMatrix<scalar_t>& b, scalar_t beta) {
365 if (!a.is_master())
return 0;
366 return (is_complex<scalar_t>() ? 4:1) *
368 ((ta==
Trans::N) ? a.rows() : a.cols(),
369 (tb==
Trans::N) ? b.cols() : b.rows(),
370 (ta==
Trans::N) ? a.cols() : a.rows(), alpha, beta);
373 template<
typename scalar_t>
long long int
374 gemv_flops(
Trans ta,
const DistributedMatrix<scalar_t>& a,
375 scalar_t alpha, scalar_t beta) {
376 auto m = (ta==
Trans::N) ? a.rows() : a.cols();
377 auto n = (ta==
Trans::N) ? a.cols() : a.rows();
378 return (is_complex<scalar_t>() ? 4:1) *
379 ((alpha != scalar_t(0.)) * m * (n * 2 - 1) +
380 (alpha != scalar_t(1.) && alpha != scalar_t(0.)) * m +
381 (beta != scalar_t(0.) && beta != scalar_t(1.)) * m +
382 (alpha != scalar_t(0.) && beta != scalar_t(0.)) * m);
385 template<
typename scalar_t>
long long int
387 if (!a.is_master())
return 0;
388 auto minrc = std::min(a.rows(), a.cols());
389 return (is_complex<scalar_t>() ? 4:1) *
390 (blas::geqrf_flops(a.rows(), minrc) +
391 blas::xxgqr_flops(a.rows(), minrc, minrc));
395 template<
typename scalar_t>
396 std::unique_ptr<const DistributedMatrixWrapper<scalar_t>>
397 ConstDistributedMatrixWrapperPtr
398 (std::size_t m, std::size_t n,
const DistributedMatrix<scalar_t>& D,
399 std::size_t i, std::size_t j) {
400 return std::unique_ptr<const DistributedMatrixWrapper<scalar_t>>
401 (
new DistributedMatrixWrapper<scalar_t>
402 (m, n,
const_cast<DistributedMatrix<scalar_t>&
>(D), i, j));
406 template<
typename scalar_t>
void gemm
407 (
Trans ta,
Trans tb, scalar_t alpha,
const DistributedMatrix<scalar_t>& A,
408 const DistributedMatrix<scalar_t>& B, scalar_t beta,
409 DistributedMatrix<scalar_t>& C);
411 template<
typename scalar_t>
void trsm
413 const DistributedMatrix<scalar_t>& A, DistributedMatrix<scalar_t>& B);
415 template<
typename scalar_t>
void trsv
416 (
UpLo ul,
Trans ta,
Diag d,
const DistributedMatrix<scalar_t>& A,
417 DistributedMatrix<scalar_t>& B);
419 template<
typename scalar_t>
void gemv
420 (
Trans ta, scalar_t alpha,
const DistributedMatrix<scalar_t>& A,
421 const DistributedMatrix<scalar_t>& X, scalar_t beta,
422 DistributedMatrix<scalar_t>& Y);
424 template<
typename scalar_t> DistributedMatrix<scalar_t>
vconcat
425 (
int cols,
int arows,
int brows,
const DistributedMatrix<scalar_t>& a,
426 const DistributedMatrix<scalar_t>& b,
const BLACSGrid* gnew,
int cxt_all);
428 template<
typename scalar_t>
void subgrid_copy_to_buffers
429 (
const DistributedMatrix<scalar_t>& a,
const DistributedMatrix<scalar_t>& b,
430 int p0,
int npr,
int npc, std::vector<std::vector<scalar_t>>& sbuf);
432 template<
typename scalar_t>
void subproc_copy_to_buffers
433 (
const DenseMatrix<scalar_t>& a,
const DistributedMatrix<scalar_t>& b,
434 int p0,
int npr,
int npc, std::vector<std::vector<scalar_t>>& sbuf);
436 template<
typename scalar_t>
void subgrid_add_from_buffers
437 (
const BLACSGrid* subg,
int master, DistributedMatrix<scalar_t>& b,
438 std::vector<scalar_t*>& pbuf);
442 #endif // DISTRIBUTED_MATRIX_HPP