33 #ifndef STRUMPACK_MPI_WRAPPER_HPP
34 #define STRUMPACK_MPI_WRAPPER_HPP
44 #include "StrumpackConfig.hpp"
46 #define OMPI_SKIP_MPICXX 1
50 #include "Triplet.hpp"
80 template<>
inline MPI_Datatype mpi_type<std::complex<float>>() {
return MPI_C_FLOAT_COMPLEX; }
83 template<>
inline MPI_Datatype mpi_type<std::complex<double>>() {
return MPI_C_DOUBLE_COMPLEX; }
85 template<>
inline MPI_Datatype mpi_type<std::pair<int,int>>() {
return MPI_2INT; }
88 template<>
inline MPI_Datatype mpi_type<std::pair<long int,long int>>() {
89 static MPI_Datatype l_l_mpi_type = MPI_DATATYPE_NULL;
90 if (l_l_mpi_type == MPI_DATATYPE_NULL) {
92 (2, strumpack::mpi_type<long int>(), &l_l_mpi_type);
93 MPI_Type_commit(&l_l_mpi_type);
98 template<>
inline MPI_Datatype mpi_type<std::pair<long long int,long long int>>() {
99 static MPI_Datatype ll_ll_mpi_type = MPI_DATATYPE_NULL;
100 if (ll_ll_mpi_type == MPI_DATATYPE_NULL) {
103 MPI_Type_commit(&ll_ll_mpi_type);
105 return ll_ll_mpi_type;
130 req_ = std::unique_ptr<MPI_Request>(
new MPI_Request());
157 void wait() { MPI_Wait(req_.get(), MPI_STATUS_IGNORE); }
160 std::unique_ptr<MPI_Request> req_;
173 inline void wait_all(std::vector<MPIRequest>& reqs) {
174 for (
auto& r : reqs) r.wait();
178 inline void wait_all(std::vector<MPI_Request>& reqs) {
179 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
232 if (comm_ != MPI_COMM_NULL && comm_ != MPI_COMM_WORLD)
233 MPI_Comm_free(&comm_);
242 if (
this != &c) duplicate(c.
comm());
254 c.comm_ = MPI_COMM_NULL;
261 MPI_Comm
comm()
const {
return comm_; }
266 bool is_null()
const {
return comm_ == MPI_COMM_NULL; }
272 assert(comm_ != MPI_COMM_NULL);
274 MPI_Comm_rank(comm_, &r);
283 assert(comm_ != MPI_COMM_NULL);
285 MPI_Comm_size(comm_, &nprocs);
301 template<
typename T>
void
302 broadcast(std::vector<T>& sbuf)
const {
303 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
306 template<
typename T>
void
307 broadcast_from(std::vector<T>& sbuf,
int src)
const {
308 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), src, comm_);
311 template<
typename T, std::
size_t N>
void
312 broadcast(std::array<T,N>& sbuf)
const {
313 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
316 template<
typename T>
void broadcast(T& data)
const {
317 MPI_Bcast(&data, 1, mpi_type<T>(), 0, comm_);
319 template<
typename T>
void broadcast_from(T& data,
int src)
const {
320 MPI_Bcast(&data, 1, mpi_type<T>(), src, comm_);
322 template<
typename T>
void
323 broadcast(T* sbuf, std::size_t ssize)
const {
324 MPI_Bcast(sbuf, ssize, mpi_type<T>(), 0, comm_);
326 template<
typename T>
void
327 broadcast_from(T* sbuf, std::size_t ssize,
int src)
const {
328 MPI_Bcast(sbuf, ssize, mpi_type<T>(), src, comm_);
332 void all_gather(T* buf, std::size_t rsize)
const {
334 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
335 buf, rsize, mpi_type<T>(), comm_);
339 void all_gather_v(T* buf,
const int* rcnts,
const int* displs)
const {
341 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, buf, rcnts, displs,
342 mpi_type<T>(), comm_);
346 void gather(T* sbuf,
int ssize,
int* rbuf,
int rsize,
int root)
const {
348 (sbuf, ssize, mpi_type<T>(), rbuf,
349 rsize, mpi_type<T>(), root, comm_);
353 void gather_v(T* sbuf,
int scnts, T* rbuf,
const int* rcnts,
354 const int* displs,
int root)
const {
356 (sbuf, scnts, mpi_type<T>(), rbuf, rcnts, displs,
357 mpi_type<T>(), root, comm_);
378 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
379 dest, tag, comm_, req.req_.get());
396 void isend(
const std::vector<T>& sbuf,
int dest,
int tag,
397 MPI_Request* req)
const {
399 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
400 dest, tag, comm_, req);
404 void isend(
const T* sbuf, std::size_t ssize,
int dest,
405 int tag, MPI_Request* req)
const {
407 MPI_Isend(
const_cast<T*
>(sbuf), ssize, mpi_type<T>(),
408 dest, tag, comm_, req);
411 void send(
const T* sbuf, std::size_t ssize,
int dest,
int tag)
const {
413 MPI_Send(
const_cast<T*
>(sbuf), ssize, mpi_type<T>(), dest, tag, comm_);
417 void isend(
const T& buf,
int dest,
int tag, MPI_Request* req)
const {
419 MPI_Isend(
const_cast<T*
>(&buf), 1, mpi_type<T>(),
420 dest, tag, comm_, req);
437 void send(
const std::vector<T>& sbuf,
int dest,
int tag)
const {
439 MPI_Send(
const_cast<T*
>(sbuf.data()), sbuf.size(),
440 mpi_type<T>(), dest, tag, comm_);
455 template<
typename T> std::vector<T>
recv(
int src,
int tag)
const {
457 MPI_Probe(src, tag, comm_, &stat);
459 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
461 std::vector<T> rbuf(msgsize);
462 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), src, tag,
463 comm_, MPI_STATUS_IGNORE);
468 std::pair<int,std::vector<T>> recv_any_src(
int tag)
const {
470 MPI_Probe(MPI_ANY_SOURCE, tag, comm_, &stat);
472 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
473 std::vector<T> rbuf(msgsize);
474 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), stat.MPI_SOURCE,
475 tag, comm_, MPI_STATUS_IGNORE);
476 return {stat.MPI_SOURCE, std::move(rbuf)};
479 template<
typename T>
T recv_one(
int src,
int tag)
const {
481 MPI_Recv(&t, 1, mpi_type<T>(), src, tag, comm_, MPI_STATUS_IGNORE);
486 void irecv(
const T* rbuf, std::size_t rsize,
int src,
487 int tag, MPI_Request* req)
const {
489 MPI_Irecv(
const_cast<T*
>(rbuf), rsize, mpi_type<T>(),
490 src, tag, comm_, req);
494 void recv(
const T* rbuf, std::size_t rsize,
int src,
int tag)
const {
497 MPI_Recv(
const_cast<T*
>(rbuf), rsize, mpi_type<T>(),
498 src, tag, comm_, &stat);
516 MPI_Allreduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, comm_);
534 template<
typename T>
T reduce(T t, MPI_Op op)
const {
536 MPI_Reduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, 0, comm_);
537 else MPI_Reduce(&t, &t, 1, mpi_type<T>(), op, 0, comm_);
556 template<
typename T>
void all_reduce(T* t,
int ssize, MPI_Op op)
const {
557 MPI_Allreduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, comm_);
560 template<
typename T>
void all_reduce(std::vector<T>& t, MPI_Op op)
const {
580 template<
typename T>
void
581 reduce(T* t,
int ssize, MPI_Op op,
int dest=0)
const {
583 MPI_Reduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, dest, comm_);
584 else MPI_Reduce(t, t, ssize, mpi_type<T>(), op, dest, comm_);
588 void all_to_all(
const T* sbuf,
int scnt, T* rbuf)
const {
590 (sbuf, scnt, mpi_type<T>(), rbuf, scnt, mpi_type<T>(), comm_);
593 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
594 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
595 int* rcnts,
int* rdispls)
const {
596 std::size_t rsize = 0;
597 for (
int p=0; p<
size(); p++)
599 std::vector<T,A> rbuf(rsize);
601 (sbuf, scnts, sdispls, mpi_type<T>(),
602 rbuf.data(), rcnts, rdispls, mpi_type<T>(), comm_);
606 template<
typename T>
void
607 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
608 T* rbuf,
int* rcnts,
int* rdispls)
const {
610 (sbuf, scnts, sdispls, mpi_type<T>(),
611 rbuf, rcnts, rdispls, mpi_type<T>(), comm_);
631 template<
typename T,
typename A=std::allocator<T>>
void
632 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
633 std::vector<T*>& pbuf)
const {
650 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
652 std::vector<T,A> rbuf;
653 std::vector<T*> pbuf;
675 template<
typename T,
typename A=std::allocator<T>>
void
676 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
677 std::vector<T*>& pbuf,
const MPI_Datatype Ttype)
const {
678 assert(sbuf.size() == std::size_t(
size()));
680 std::unique_ptr<int[]> iwork(
new int[4*P]);
681 auto ssizes = iwork.get();
682 auto rsizes = ssizes + P;
683 auto sdispl = ssizes + 2*P;
684 auto rdispl = ssizes + 3*P;
685 for (
int p=0; p<P; p++) {
687 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
688 std::cerr <<
"# ERROR: 32bit integer overflow in all_to_all_v!!"
692 ssizes[p] = sbuf[p].size();
696 std::size_t totssize = std::accumulate(ssizes, ssizes+P, std::size_t(0)),
697 totrsize = std::accumulate(rsizes, rsizes+P, std::size_t(0));
699 static_cast<std::size_t
>(std::numeric_limits<int>::max()) ||
701 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
707 rbuf.resize(totrsize);
708 std::unique_ptr<MPI_Request[]> reqs(
new MPI_Request[2*P]);
709 std::size_t displ = 0;
711 for (
int p=0; p<P; p++) {
712 pbuf[p] = rbuf.data() + displ;
713 MPI_Irecv(pbuf[p], rsizes[p], Ttype, p, 0, comm_, reqs.get()+p);
716 for (
int p=0; p<P; p++)
718 (sbuf[p].data(), ssizes[p], Ttype, p, 0, comm_, reqs.get()+P+p);
719 MPI_Waitall(2*P, reqs.get(), MPI_STATUSES_IGNORE);
720 std::vector<std::vector<T>>().swap(sbuf);
722 std::unique_ptr<T[]> sendbuf_(
new T[totssize]);
723 auto sendbuf = sendbuf_.get();
724 sdispl[0] = rdispl[0] = 0;
725 for (
int p=1; p<P; p++) {
726 sdispl[p] = sdispl[p-1] + ssizes[p-1];
727 rdispl[p] = rdispl[p-1] + rsizes[p-1];
729 for (
int p=0; p<P; p++)
730 std::copy(sbuf[p].begin(), sbuf[p].end(), sendbuf+sdispl[p]);
731 std::vector<std::vector<T>>().swap(sbuf);
732 rbuf.resize(totrsize);
733 MPI_Alltoallv(sendbuf, ssizes, sdispl, Ttype,
734 rbuf.data(), rsizes, rdispl, Ttype, comm_);
736 for (
int p=0; p<P; p++)
737 pbuf[p] = rbuf.data() + rdispl[p];
758 assert(P0 + P <=
size());
760 std::vector<int> sub_ranks(P);
761 for (
int i=0; i<P; i++)
762 sub_ranks[i] = P0 + i*stride;
763 MPI_Group group, sub_group;
764 MPI_Comm_group(comm_, &group);
765 MPI_Group_incl(group, P, sub_ranks.data(), &sub_group);
766 MPI_Comm_create(comm_, sub_group, &sub_comm.comm_);
767 MPI_Group_free(&group);
768 MPI_Group_free(&sub_group);
785 MPI_Group group, sub_group;
786 MPI_Comm_group(comm_, &group);
787 MPI_Group_incl(group, 1, &p, &sub_group);
788 MPI_Comm_create(comm_, sub_group, &c0.comm_);
789 MPI_Group_free(&group);
790 MPI_Group_free(&sub_group);
798 MPI_Pcontrol(1, name.c_str());
804 MPI_Pcontrol(-1, name.c_str());
807 static bool initialized() {
809 MPI_Initialized(&flag);
810 return static_cast<bool>(flag);
814 MPI_Comm comm_ = MPI_COMM_WORLD;
816 void duplicate(MPI_Comm c) {
817 if (c == MPI_COMM_NULL) comm_ = c;
818 else MPI_Comm_dup(c, &comm_);
831 assert(c != MPI_COMM_NULL);
833 MPI_Comm_rank(c, &rank);
845 assert(c != MPI_COMM_NULL);
847 MPI_Comm_size(c, &nprocs);
Contains the definition of some useful (global) variables.
Wrapper class around an MPI_Comm object.
Definition: MPIWrapper.hpp:194
virtual ~MPIComm()
Definition: MPIWrapper.hpp:231
MPIComm sub(int P0, int P, int stride=1) const
Definition: MPIWrapper.hpp:756
bool is_root() const
Definition: MPIWrapper.hpp:293
static void control_stop(const std::string &name)
Definition: MPIWrapper.hpp:803
T all_reduce(T t, MPI_Op op) const
Definition: MPIWrapper.hpp:515
void reduce(T *t, int ssize, MPI_Op op, int dest=0) const
Definition: MPIWrapper.hpp:581
MPIComm()
Definition: MPIWrapper.hpp:200
void barrier() const
Definition: MPIWrapper.hpp:299
MPIComm(const MPIComm &c)
Definition: MPIWrapper.hpp:217
std::vector< T, A > all_to_all_v(std::vector< std::vector< T >> &sbuf) const
Definition: MPIWrapper.hpp:651
void send(const std::vector< T > &sbuf, int dest, int tag) const
Definition: MPIWrapper.hpp:437
bool is_null() const
Definition: MPIWrapper.hpp:266
MPI_Comm comm() const
Definition: MPIWrapper.hpp:261
MPIComm & operator=(MPIComm &&c) noexcept
Definition: MPIWrapper.hpp:252
T reduce(T t, MPI_Op op) const
Definition: MPIWrapper.hpp:534
void isend(const std::vector< T > &sbuf, int dest, int tag, MPI_Request *req) const
Definition: MPIWrapper.hpp:396
void all_to_all_v(std::vector< std::vector< T >> &sbuf, std::vector< T, A > &rbuf, std::vector< T * > &pbuf, const MPI_Datatype Ttype) const
Definition: MPIWrapper.hpp:676
MPIComm(MPIComm &&c) noexcept
Definition: MPIWrapper.hpp:225
int size() const
Definition: MPIWrapper.hpp:282
int rank() const
Definition: MPIWrapper.hpp:271
MPIComm & operator=(const MPIComm &c)
Definition: MPIWrapper.hpp:241
void all_to_all_v(std::vector< std::vector< T >> &sbuf, std::vector< T, A > &rbuf, std::vector< T * > &pbuf) const
Definition: MPIWrapper.hpp:632
MPIComm sub_self(int p) const
Definition: MPIWrapper.hpp:782
static void control_start(const std::string &name)
Definition: MPIWrapper.hpp:797
MPIComm(MPI_Comm c)
Definition: MPIWrapper.hpp:209
std::vector< T > recv(int src, int tag) const
Definition: MPIWrapper.hpp:455
MPIRequest isend(const std::vector< T > &sbuf, int dest, int tag) const
Definition: MPIWrapper.hpp:375
void all_reduce(T *t, int ssize, MPI_Op op) const
Definition: MPIWrapper.hpp:556
Wrapper around an MPI_Request object.
Definition: MPIWrapper.hpp:124
MPIRequest(const MPIRequest &)=delete
void wait()
Definition: MPIWrapper.hpp:157
MPIRequest(MPIRequest &&)=default
MPIRequest()
Definition: MPIWrapper.hpp:129
MPIRequest & operator=(const MPIRequest &)=delete
MPIRequest & operator=(MPIRequest &&)=default
Definition: StrumpackOptions.hpp:42
MPI_Datatype mpi_type()
Definition: MPIWrapper.hpp:60
MPI_Datatype mpi_type< int >()
Definition: MPIWrapper.hpp:67
void wait_all(std::vector< MPIRequest > &reqs)
Definition: MPIWrapper.hpp:173
MPI_Datatype mpi_type< long >()
Definition: MPIWrapper.hpp:69
MPI_Datatype mpi_type< float >()
Definition: MPIWrapper.hpp:75
int mpi_rank(MPI_Comm c=MPI_COMM_WORLD)
Definition: MPIWrapper.hpp:830
MPI_Datatype mpi_type< long long int >()
Definition: MPIWrapper.hpp:73
MPI_Datatype mpi_type< bool >()
Definition: MPIWrapper.hpp:65
void copy(std::size_t m, std::size_t n, const DenseMatrix< scalar_from_t > &a, std::size_t ia, std::size_t ja, DenseMatrix< scalar_to_t > &b, std::size_t ib, std::size_t jb)
Definition: DenseMatrix.hpp:1231
MPI_Datatype mpi_type< double >()
Definition: MPIWrapper.hpp:77
int mpi_nprocs(MPI_Comm c=MPI_COMM_WORLD)
Definition: MPIWrapper.hpp:844
MPI_Datatype mpi_type< unsigned long >()
Definition: MPIWrapper.hpp:71
MPI_Datatype mpi_type< char >()
Definition: MPIWrapper.hpp:62