33 #ifndef STRUMPACK_MPI_WRAPPER_HPP
34 #define STRUMPACK_MPI_WRAPPER_HPP
44 #define OMPI_SKIP_MPICXX 1
48 #include "Triplet.hpp"
76 template<>
inline MPI_Datatype mpi_type<std::complex<float>>() {
return MPI_CXX_FLOAT_COMPLEX; }
78 template<>
inline MPI_Datatype mpi_type<std::complex<double>>() {
return MPI_CXX_DOUBLE_COMPLEX; }
80 template<>
inline MPI_Datatype mpi_type<std::pair<int,int>>() {
return MPI_2INT; }
/**
 * mpi_type specialization for std::pair<long int,long int>.
 *
 * The datatype handle is cached in a function-local static that starts
 * out as MPI_DATATYPE_NULL; while it still has that value the type is
 * built from 2 elements of mpi_type<long int>() and committed.
 */
83 template<>
inline MPI_Datatype mpi_type<std::pair<long int,long int>>() {
84 static MPI_Datatype l_l_mpi_type = MPI_DATATYPE_NULL;
// Only construct the derived type when it has not been created yet.
85 if (l_l_mpi_type == MPI_DATATYPE_NULL) {
// NOTE(review): the name of the function receiving these arguments is
// not visible in this excerpt -- presumably MPI_Type_contiguous; confirm.
87 (2, strumpack::mpi_type<long int>(), &l_l_mpi_type);
// A derived datatype must be committed before it is used in communication.
88 MPI_Type_commit(&l_l_mpi_type);
/**
 * mpi_type specialization for std::pair<long long int,long long int>.
 *
 * The datatype handle is cached in a function-local static that starts
 * out as MPI_DATATYPE_NULL; while it still has that value the type is
 * created and committed. The cached handle is returned.
 */
93 template<>
inline MPI_Datatype mpi_type<std::pair<long long int,long long int>>() {
94 static MPI_Datatype ll_ll_mpi_type = MPI_DATATYPE_NULL;
// Only construct the derived type when it has not been created yet.
// NOTE(review): the call that builds the type is not visible in this excerpt.
95 if (ll_ll_mpi_type == MPI_DATATYPE_NULL) {
// A derived datatype must be committed before it is used in communication.
98 MPI_Type_commit(&ll_ll_mpi_type);
100 return ll_ll_mpi_type;
125 req_ = std::unique_ptr<MPI_Request>(
new MPI_Request());
152 void wait() { MPI_Wait(req_.get(), MPI_STATUS_IGNORE); }
155 std::unique_ptr<MPI_Request> req_;
/**
 * Wait for every request in a vector of MPIRequest wrappers.
 * The requests are waited on one after the other, in vector order, by
 * calling wait() on each element.
 *
 * \param reqs requests to complete
 */
168 inline void wait_all(std::vector<MPIRequest>& reqs) {
169 for (
auto& r : reqs) r.wait();
/**
 * Wait for every request in a vector of raw MPI_Request handles using a
 * single MPI_Waitall call; the statuses are ignored.
 *
 * \param reqs requests to complete; reqs.size() is passed as the
 * (int) count argument of MPI_Waitall
 */
173 inline void wait_all(std::vector<MPI_Request>& reqs) {
174 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
228 if (comm_ != MPI_COMM_NULL && comm_ != MPI_COMM_WORLD)
229 MPI_Comm_free(&comm_);
238 if (
this != &c) duplicate(c.
comm());
250 c.comm_ = MPI_COMM_NULL;
257 MPI_Comm
comm()
const {
return comm_; }
262 bool is_null()
const {
return comm_ == MPI_COMM_NULL; }
268 assert(comm_ != MPI_COMM_NULL);
270 MPI_Comm_rank(comm_, &r);
279 assert(comm_ != MPI_COMM_NULL);
281 MPI_Comm_size(comm_, &nprocs);
/**
 * Broadcast the elements of a vector over comm_ with rank 0 as root.
 * sbuf.size() elements of type mpi_type<T>() are transferred. The
 * vector is not resized by the visible code, so (per MPI_Bcast) every
 * rank must pass a vector of the same length.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param sbuf source on rank 0, destination on all other ranks
 */
297 template<
typename T>
void
298 broadcast(std::vector<T>& sbuf)
const {
299 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
/**
 * Broadcast the elements of a vector over comm_ from rank src.
 * sbuf.size() elements of type mpi_type<T>() are transferred. The
 * vector is not resized by the visible code, so (per MPI_Bcast) every
 * rank must pass a vector of the same length.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param sbuf source on rank src, destination on all other ranks
 * \param src rank (in comm_) used as root of the broadcast
 */
301 template<
typename T>
void
302 broadcast_from(std::vector<T>& sbuf,
int src)
const {
303 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), src, comm_);
/**
 * Broadcast a fixed-size std::array over comm_ with rank 0 as root.
 * All N elements are sent as type mpi_type<T>().
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \tparam N number of elements in the array
 * \param sbuf source on rank 0, destination on all other ranks
 */
306 template<
typename T, std::
size_t N>
void
307 broadcast(std::array<T,N>& sbuf)
const {
308 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
/**
 * Broadcast a single value of type T over comm_ with rank 0 as root.
 *
 * \tparam T value type, needs an mpi_type<T>() mapping
 * \param data value sent by rank 0 and overwritten on all other ranks
 */
311 template<
typename T>
void
312 broadcast(T& data)
const {
313 MPI_Bcast(&data, 1, mpi_type<T>(), 0, comm_);
/**
 * Broadcast a single value of type T over comm_ from rank src.
 *
 * \tparam T value type, needs an mpi_type<T>() mapping
 * \param data value sent by rank src and overwritten on all other ranks
 * \param src rank (in comm_) used as root of the broadcast
 */
315 template<
typename T>
void
316 broadcast_from(T& data,
int src)
const {
317 MPI_Bcast(&data, 1, mpi_type<T>(), src, comm_);
/**
 * Broadcast a raw buffer over comm_ with rank 0 as root.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param sbuf pointer to the buffer (source on rank 0, destination elsewhere)
 * \param ssize number of elements; handed to MPI_Bcast as its count
 * argument (an int in the classic C binding)
 */
319 template<
typename T>
void
320 broadcast(T* sbuf, std::size_t ssize)
const {
321 MPI_Bcast(sbuf, ssize, mpi_type<T>(), 0, comm_);
/**
 * Broadcast a raw buffer over comm_ from rank src.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param sbuf pointer to the buffer (source on rank src, destination elsewhere)
 * \param ssize number of elements; handed to MPI_Bcast as its count
 * argument (an int in the classic C binding)
 * \param src rank (in comm_) used as root of the broadcast
 */
323 template<
typename T>
void
324 broadcast_from(T* sbuf, std::size_t ssize,
int src)
const {
325 MPI_Bcast(sbuf, ssize, mpi_type<T>(), src, comm_);
/**
 * In-place gather over comm_: the send side is given as MPI_IN_PLACE
 * and buf is used as the receive buffer, with rsize elements of type
 * mpi_type<T>() per rank.
 * NOTE(review): the name of the MPI routine called is not visible in
 * this excerpt -- presumably MPI_Allgather; confirm.
 *
 * \param buf receive buffer; with MPI_IN_PLACE it also holds this
 * rank's own contribution
 * \param rsize number of elements received from each rank
 */
329 void all_gather(T* buf, std::size_t rsize)
const {
331 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
332 buf, rsize, mpi_type<T>(), comm_);
/**
 * In-place variable-size gather over comm_: the send side is given as
 * MPI_IN_PLACE and buf is used as the receive buffer.
 * NOTE(review): the name of the MPI routine called is not visible in
 * this excerpt -- presumably MPI_Allgatherv; confirm.
 *
 * \param buf receive buffer; with MPI_IN_PLACE it also holds this
 * rank's own contribution
 * \param rcnts per-rank receive counts
 * \param displs per-rank displacements (in elements) into buf
 */
336 void all_gather_v(T* buf,
const int* rcnts,
const int* displs)
const {
338 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, buf, rcnts, displs,
339 mpi_type<T>(), comm_);
360 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
361 dest, tag, comm_, req.req_.get());
/**
 * Start a non-blocking send of a vector to rank dest, storing the
 * request handle in the caller-provided MPI_Request.
 *
 * \param sbuf data to send; must stay alive and unmodified until the
 * request has completed (MPI_Isend requirement)
 * \param dest destination rank in comm_
 * \param tag message tag
 * \param req receives the request handle to wait on
 */
378 void isend(
const std::vector<T>& sbuf,
int dest,
int tag,
379 MPI_Request* req)
const {
// The const_cast only adapts the pointer type for the MPI call.
// NOTE(review): presumably needed for MPI headers that declare the
// send buffer as non-const -- confirm.
381 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
382 dest, tag, comm_, req);
/**
 * Start a non-blocking send of a raw buffer to rank dest.
 *
 * \param sbuf pointer to the data; must stay valid and unmodified until
 * the request has completed (MPI_Isend requirement)
 * \param ssize number of elements of type mpi_type<T>() to send
 * \param dest destination rank in comm_
 * \param tag message tag
 * \param req receives the request handle to wait on
 */
386 void isend(
const T* sbuf, std::size_t ssize,
int dest,
387 int tag, MPI_Request* req)
const {
// The const_cast only adapts the pointer type for the MPI call.
389 MPI_Isend(
const_cast<T*
>(sbuf), ssize, mpi_type<T>(),
390 dest, tag, comm_, req);
/**
 * Blocking send (MPI_Send) of a raw buffer to rank dest.
 *
 * \param sbuf pointer to the data to send
 * \param ssize number of elements of type mpi_type<T>() to send
 * \param dest destination rank in comm_
 * \param tag message tag
 */
393 void send(
const T* sbuf, std::size_t ssize,
int dest,
int tag)
const {
// The const_cast only adapts the pointer type for the MPI call.
395 MPI_Send(
const_cast<T*
>(sbuf), ssize, mpi_type<T>(), dest, tag, comm_);
/**
 * Start a non-blocking send of a single value to rank dest.
 *
 * \param buf value to send; it is passed by reference, so the referenced
 * object must stay alive and unmodified until the request has completed
 * \param dest destination rank in comm_
 * \param tag message tag
 * \param req receives the request handle to wait on
 */
399 void isend(
const T& buf,
int dest,
int tag, MPI_Request* req)
const {
// The const_cast only adapts the pointer type for the MPI call.
401 MPI_Isend(
const_cast<T*
>(&buf), 1, mpi_type<T>(),
402 dest, tag, comm_, req);
/**
 * Blocking send (MPI_Send) of all elements of a vector to rank dest.
 *
 * \param sbuf data to send; sbuf.size() elements of type mpi_type<T>()
 * \param dest destination rank in comm_
 * \param tag message tag
 */
419 void send(
const std::vector<T>& sbuf,
int dest,
int tag)
const {
// The const_cast only adapts the pointer type for the MPI call.
421 MPI_Send(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(), dest, tag, comm_);
/**
 * Blocking receive of a message whose length is not known in advance.
 * The message is probed first, its element count is queried, a vector
 * of that size is allocated and the message is then received into it.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param src rank to receive from
 * \param tag message tag to match
 * \return std::vector<T> (the return statement itself is not visible in
 * this excerpt -- presumably the received buffer)
 */
436 template<
typename T> std::vector<T>
recv(
int src,
int tag)
const {
// Block until a matching message is pending, without receiving it yet.
438 MPI_Probe(src, tag, comm_, &stat);
// Number of mpi_type<T>() elements in the pending message.
440 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
442 std::vector<T> rbuf(msgsize);
443 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), src, tag,
444 comm_, MPI_STATUS_IGNORE);
/**
 * Blocking receive of a message with the given tag from any rank.
 * The message is probed with MPI_ANY_SOURCE, its element count is
 * queried, and it is then received from the rank reported by the probe.
 *
 * \param tag message tag to match
 * \return pair of (sending rank, received data)
 */
449 std::pair<int,std::vector<T>> recv_any_src(
int tag)
const {
451 MPI_Probe(MPI_ANY_SOURCE, tag, comm_, &stat);
// Number of mpi_type<T>() elements in the pending message.
453 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
454 std::vector<T> rbuf(msgsize);
// Receive from the specific rank found by the probe, so the message
// received is from the same sender the size was queried for.
455 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), stat.MPI_SOURCE,
456 tag, comm_, MPI_STATUS_IGNORE);
457 return {stat.MPI_SOURCE, std::move(rbuf)};
/**
 * Blocking receive of exactly one element of type T.
 *
 * \tparam T value type, needs an mpi_type<T>() mapping
 * \param src rank to receive from
 * \param tag message tag to match
 * \return a T (the declaration of t and the return statement are not
 * visible in this excerpt -- presumably the received value)
 */
460 template<
typename T>
T recv_one(
int src,
int tag)
const {
462 MPI_Recv(&t, 1, mpi_type<T>(), src, tag, comm_, MPI_STATUS_IGNORE);
/**
 * Start a non-blocking receive into a caller-provided buffer.
 *
 * \param rbuf destination buffer. NOTE(review): declared const T*, but
 * constness is cast away and MPI_Irecv writes into it, so the caller
 * must pass writable storage of at least rsize elements.
 * \param rsize maximum number of elements of type mpi_type<T>() to receive
 * \param src rank to receive from
 * \param tag message tag to match
 * \param req receives the request handle to wait on
 */
467 void irecv(
const T* rbuf, std::size_t rsize,
int src,
468 int tag, MPI_Request* req)
const {
470 MPI_Irecv(
const_cast<T*
>(rbuf), rsize, mpi_type<T>(),
471 src, tag, comm_, req);
/**
 * Blocking receive (MPI_Recv) into a caller-provided buffer.
 *
 * \param rbuf destination buffer. NOTE(review): declared const T*, but
 * constness is cast away and MPI_Recv writes into it, so the caller
 * must pass writable storage of at least rsize elements.
 * \param rsize maximum number of elements of type mpi_type<T>() to receive
 * \param src rank to receive from
 * \param tag message tag to match
 */
475 void recv(
const T* rbuf, std::size_t rsize,
int src,
int tag)
const {
// The status is captured in stat; its declaration and any later use
// are not visible in this excerpt.
478 MPI_Recv(
const_cast<T*
>(rbuf), rsize, mpi_type<T>(),
479 src, tag, comm_, &stat);
497 MPI_Allreduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, comm_);
/**
 * Reduce a single value over comm_ with operation op, using rank 0 as
 * the root of the MPI_Reduce.
 *
 * \tparam T value type, needs an mpi_type<T>() mapping
 * \param t this rank's contribution (taken by value)
 * \param op MPI reduction operation
 * \return a T (the return statement is not visible in this excerpt --
 * presumably t, which holds the reduced value only on the root)
 */
515 template<
typename T>
T reduce(T t, MPI_Op op)
const {
// Two variants of the call: one reduces in place (MPI_IN_PLACE as send
// buffer), the other passes &t as both send and receive buffer.
// NOTE(review): the condition selecting between them is not visible
// here -- presumably a root check, since MPI_IN_PLACE is only valid as
// send buffer on the root and the receive buffer is only significant
// there; confirm.
517 MPI_Reduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, 0, comm_);
518 else MPI_Reduce(&t, &t, 1, mpi_type<T>(), op, 0, comm_);
/**
 * In-place all-reduce of an array over comm_: every rank contributes
 * ssize elements from t and t is overwritten with the combined result.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param t pointer to the data; input and output
 * \param ssize number of elements
 * \param op MPI reduction operation
 */
537 template<
typename T>
void all_reduce(T* t,
int ssize, MPI_Op op)
const {
538 MPI_Allreduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, comm_);
541 template<
typename T>
void all_reduce(std::vector<T>& t, MPI_Op op)
const {
/**
 * Reduce an array over comm_ with operation op onto rank dest.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param t pointer to ssize elements; this rank's contribution
 * \param ssize number of elements
 * \param op MPI reduction operation
 * \param dest rank used as root of the reduction, defaults to 0
 */
561 template<
typename T>
void
562 reduce(T* t,
int ssize, MPI_Op op,
int dest=0)
const {
// Two variants of the call: one reduces in place (MPI_IN_PLACE as send
// buffer), the other passes t as both send and receive buffer.
// NOTE(review): the condition selecting between them is not visible
// here -- presumably a check for rank dest, since MPI_IN_PLACE is only
// valid as send buffer on the root; confirm.
564 MPI_Reduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, dest, comm_);
565 else MPI_Reduce(t, t, ssize, mpi_type<T>(), op, dest, comm_);
/**
 * All-to-all exchange over comm_ with the same count scnt and the same
 * datatype mpi_type<T>() on both the send and the receive side.
 * NOTE(review): the name of the MPI routine called is not visible in
 * this excerpt -- presumably MPI_Alltoall; confirm.
 *
 * \param sbuf send buffer
 * \param scnt element count, used for both the send and receive count
 * \param rbuf receive buffer
 */
569 void all_to_all(
const T* sbuf,
int scnt, T* rbuf)
const {
571 (sbuf, scnt, mpi_type<T>(), rbuf, scnt, mpi_type<T>(), comm_);
/**
 * Variable-size all-to-all exchange that allocates the receive buffer.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \tparam A allocator of the returned vector
 * \param sbuf send buffer
 * \param scnts per-rank send counts
 * \param sdispls per-rank send displacements
 * \param rcnts per-rank receive counts
 * \param rdispls per-rank receive displacements
 * \return std::vector<T,A> (the return statement is not visible in this
 * excerpt -- presumably the receive buffer)
 */
574 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
575 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
576 int* rcnts,
int* rdispls)
const {
// rsize is computed in a loop over all ranks of the communicator.
// NOTE(review): the loop body is not visible here -- presumably it sums
// rcnts[p]; confirm.
577 std::size_t rsize = 0;
578 for (
int p=0; p<
size(); p++)
580 std::vector<T,A> rbuf(rsize);
// NOTE(review): callee name not visible -- presumably MPI_Alltoallv.
582 (sbuf, scnts, sdispls, mpi_type<T>(),
583 rbuf.data(), rcnts, rdispls, mpi_type<T>(), comm_);
/**
 * Variable-size all-to-all exchange into a caller-provided receive
 * buffer; the same datatype mpi_type<T>() is used on both sides.
 * NOTE(review): the name of the MPI routine called is not visible in
 * this excerpt -- presumably MPI_Alltoallv; confirm.
 *
 * \tparam T element type, needs an mpi_type<T>() mapping
 * \param sbuf send buffer
 * \param scnts per-rank send counts
 * \param sdispls per-rank send displacements
 * \param rbuf receive buffer, sized by the caller
 * \param rcnts per-rank receive counts
 * \param rdispls per-rank receive displacements
 */
587 template<
typename T>
void
588 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
589 T* rbuf,
int* rcnts,
int* rdispls)
const {
591 (sbuf, scnts, sdispls, mpi_type<T>(),
592 rbuf, rcnts, rdispls, mpi_type<T>(), comm_);
612 template<
typename T,
typename A=std::allocator<T>>
void
613 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
614 std::vector<T*>& pbuf)
const {
631 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
633 std::vector<T,A> rbuf;
634 std::vector<T*> pbuf;
/**
 * Personalized all-to-all exchange with an explicit MPI datatype.
 * sbuf[p] holds the data for rank p. Everything received is stored
 * contiguously in rbuf, and pbuf[p] is set to point at the start of the
 * data received from rank p inside rbuf. sbuf is released (swapped with
 * an empty vector) once its contents have been sent or copied.
 *
 * Two exchange strategies are visible: one using per-rank
 * MPI_Irecv plus a matching per-rank send and MPI_Waitall, and one that
 * packs the data into a single send buffer and calls MPI_Alltoallv.
 * NOTE(review): several lines of this function are missing from this
 * excerpt, including the branch structure; the comments below describe
 * only what is visible.
 *
 * \tparam T element type
 * \tparam A allocator of the receive vector
 * \param sbuf per-rank send data; must have one entry per rank and is
 * emptied by this call
 * \param rbuf receive buffer, resized to the total receive size
 * \param pbuf per-rank pointers into rbuf
 * \param Ttype MPI datatype describing T
 */
656 template<
typename T,
typename A=std::allocator<T>>
void
657 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
658 std::vector<T*>& pbuf,
const MPI_Datatype Ttype)
const {
659 assert(sbuf.size() == std::size_t(
size()));
// One allocation of 4*P ints, carved into four arrays of length P:
// send sizes, receive sizes, send displacements, receive displacements.
661 std::unique_ptr<int[]> iwork(
new int[4*P]);
662 auto ssizes = iwork.get();
663 auto rsizes = ssizes + P;
664 auto sdispl = ssizes + 2*P;
665 auto rdispl = ssizes + 3*P;
// Record the per-rank send sizes as int; a comparison against
// std::numeric_limits<int>::max() guards an error message about 32-bit
// overflow (the compared expression is not visible in this excerpt).
666 for (
int p=0; p<P; p++) {
668 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
669 std::cerr <<
"# ERROR: 32bit integer overflow in all_to_all_v!!"
673 ssizes[p] = sbuf[p].size();
// Totals are accumulated in std::size_t so the sums themselves cannot
// wrap an int.
// NOTE(review): how rsizes gets filled is not visible here.
677 std::size_t totssize = std::accumulate(ssizes, ssizes+P, std::size_t(0)),
678 totrsize = std::accumulate(rsizes, rsizes+P, std::size_t(0));
// NOTE(review): the left-hand sides of these comparisons are not
// visible -- presumably the totals are tested against INT_MAX to pick
// the point-to-point path, as int displacements would overflow; confirm.
680 static_cast<std::size_t
>(std::numeric_limits<int>::max()) ||
682 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
688 rbuf.resize(totrsize);
// Requests [0,P) are used for the receives, [P,2P) for the sends.
689 std::unique_ptr<MPI_Request[]> reqs(
new MPI_Request[2*P]);
// displ is a std::size_t offset into rbuf (its update is not visible).
690 std::size_t displ = 0;
692 for (
int p=0; p<P; p++) {
693 pbuf[p] = rbuf.data() + displ;
694 MPI_Irecv(pbuf[p], rsizes[p], Ttype, p, 0, comm_, reqs.get()+p);
// NOTE(review): callee name not visible -- presumably MPI_Isend.
697 for (
int p=0; p<P; p++)
699 (sbuf[p].data(), ssizes[p], Ttype, p, 0, comm_, reqs.get()+P+p);
700 MPI_Waitall(2*P, reqs.get(), MPI_STATUSES_IGNORE);
// Swap with an empty temporary to actually free the send storage.
701 std::vector<std::vector<T>>().swap(sbuf);
// Collective path: pack all per-rank send data into one contiguous buffer.
703 std::unique_ptr<T[]> sendbuf_(
new T[totssize]);
704 auto sendbuf = sendbuf_.get();
// Displacements are running (exclusive prefix) sums of the sizes.
705 sdispl[0] = rdispl[0] = 0;
706 for (
int p=1; p<P; p++) {
707 sdispl[p] = sdispl[p-1] + ssizes[p-1];
708 rdispl[p] = rdispl[p-1] + rsizes[p-1];
710 for (
int p=0; p<P; p++)
711 std::copy(sbuf[p].begin(), sbuf[p].end(), sendbuf+sdispl[p]);
// The data now lives in sendbuf, so the original storage can be freed
// before the receive buffer is allocated.
712 std::vector<std::vector<T>>().swap(sbuf);
713 rbuf.resize(totrsize);
714 MPI_Alltoallv(sendbuf, ssizes, sdispl, Ttype,
715 rbuf.data(), rsizes, rdispl, Ttype, comm_);
// Expose where each rank's received data starts inside rbuf.
717 for (
int p=0; p<P; p++)
718 pbuf[p] = rbuf.data() + rdispl[p];
739 assert(P0 + P <=
size());
741 std::vector<int> sub_ranks(P);
742 for (
int i=0; i<P; i++)
743 sub_ranks[i] = P0 + i*stride;
744 MPI_Group group, sub_group;
745 MPI_Comm_group(comm_, &group);
746 MPI_Group_incl(group, P, sub_ranks.data(), &sub_group);
747 MPI_Comm_create(comm_, sub_group, &sub_comm.comm_);
748 MPI_Group_free(&group);
749 MPI_Group_free(&sub_group);
766 MPI_Group group, sub_group;
767 MPI_Comm_group(comm_, &group);
768 MPI_Group_incl(group, 1, &p, &sub_group);
769 MPI_Comm_create(comm_, sub_group, &c0.comm_);
770 MPI_Group_free(&group);
771 MPI_Group_free(&sub_group);
779 MPI_Pcontrol(1, name.c_str());
785 MPI_Pcontrol(-1, name.c_str());
/**
 * Query MPI_Initialized and return its flag converted to bool.
 *
 * \return true when MPI_Initialized reports a non-zero flag
 */
788 static bool initialized() {
790 MPI_Initialized(&flag);
791 return static_cast<bool>(flag);
795 MPI_Comm comm_ = MPI_COMM_WORLD;
/**
 * Make comm_ a duplicate of communicator c.
 * MPI_COMM_NULL is stored as is, because a null communicator cannot be
 * duplicated; any other communicator is copied with MPI_Comm_dup.
 * NOTE(review): the previous value of comm_ is overwritten without being
 * freed in the visible code -- callers that reassign a live communicator
 * may need to release it first; confirm.
 *
 * \param c communicator to duplicate
 */
797 void duplicate(MPI_Comm c) {
798 if (c == MPI_COMM_NULL) comm_ = c;
799 else MPI_Comm_dup(c, &comm_);
812 assert(c != MPI_COMM_NULL);
814 MPI_Comm_rank(c, &rank);
826 assert(c != MPI_COMM_NULL);
828 MPI_Comm_size(c, &nprocs);
834 #endif // STRUMPACK_MPI_WRAPPER_HPP