33 #ifndef STRUMPACK_MPI_WRAPPER_HPP
34 #define STRUMPACK_MPI_WRAPPER_HPP
44 #define OMPI_SKIP_MPICXX 1
48 #include "Triplet.hpp"
// Specialization of mpi_type for std::complex<float>: returns the
// predefined MPI_CXX_FLOAT_COMPLEX datatype handle.
76 template<>
inline MPI_Datatype mpi_type<std::complex<float>>() {
return MPI_CXX_FLOAT_COMPLEX; }
// Specialization of mpi_type for std::complex<double>: returns the
// predefined MPI_CXX_DOUBLE_COMPLEX datatype handle.
78 template<>
inline MPI_Datatype mpi_type<std::complex<double>>() {
return MPI_CXX_DOUBLE_COMPLEX; }
// Specialization of mpi_type for std::pair<int,int>: returns the
// predefined MPI_2INT datatype handle.
80 template<>
inline MPI_Datatype mpi_type<std::pair<int,int>>() {
return MPI_2INT; }
// Specialization of mpi_type for std::pair<long int,long int>.
// The handle lives in a function-local static that starts out as
// MPI_DATATYPE_NULL. Only while it still has that value is it set up
// from a count of 2 and mpi_type<long int>(), and then committed with
// MPI_Type_commit.
// NOTE(review): the name of the call receiving (2, ..., &l_l_mpi_type)
// and the return statement are not visible in this excerpt; the call is
// presumably MPI_Type_contiguous -- confirm against the full file.
83 template<>
inline MPI_Datatype mpi_type<std::pair<long int,long int>>() {
84 static MPI_Datatype l_l_mpi_type = MPI_DATATYPE_NULL;
85 if (l_l_mpi_type == MPI_DATATYPE_NULL) {
87 (2, strumpack::mpi_type<long int>(), &l_l_mpi_type);
88 MPI_Type_commit(&l_l_mpi_type);
// Specialization of mpi_type for std::pair<long long int,long long int>.
// The handle lives in a function-local static that starts out as
// MPI_DATATYPE_NULL; while it still has that value it is committed with
// MPI_Type_commit, and the static handle is what gets returned.
// NOTE(review): the call that actually constructs the datatype before
// the commit is not visible in this excerpt.
93 template<>
inline MPI_Datatype mpi_type<std::pair<long long int,long long int>>() {
94 static MPI_Datatype ll_ll_mpi_type = MPI_DATATYPE_NULL;
95 if (ll_ll_mpi_type == MPI_DATATYPE_NULL) {
98 MPI_Type_commit(&ll_ll_mpi_type);
100 return ll_ll_mpi_type;
125 req_ = std::unique_ptr<MPI_Request>(
new MPI_Request());
// Calls MPI_Wait on the request owned by req_; the resulting status is
// discarded (MPI_STATUS_IGNORE).
152 void wait() { MPI_Wait(req_.get(), MPI_STATUS_IGNORE); }
155 std::unique_ptr<MPI_Request> req_;
// Waits on every MPIRequest in reqs by calling its wait() member, one
// request after another in container order.
168 inline void wait_all(std::vector<MPIRequest>& reqs) {
169 for (
auto& r : reqs) r.wait();
// Overload for raw MPI_Request handles: passes the whole vector to a
// single MPI_Waitall call and ignores all statuses.
173 inline void wait_all(std::vector<MPI_Request>& reqs) {
174 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
228 if (comm_ != MPI_COMM_NULL && comm_ != MPI_COMM_WORLD)
229 MPI_Comm_free(&comm_);
238 if (
this != &c) duplicate(c.
comm());
250 c.comm_ = MPI_COMM_NULL;
// Returns, by value, the MPI_Comm handle stored in comm_.
257 MPI_Comm
comm()
const {
return comm_; }
// Returns true when the stored handle comm_ equals MPI_COMM_NULL.
262 bool is_null()
const {
return comm_ == MPI_COMM_NULL; }
268 assert(comm_ != MPI_COMM_NULL);
270 MPI_Comm_rank(comm_, &r);
279 assert(comm_ != MPI_COMM_NULL);
281 MPI_Comm_size(comm_, &nprocs);
// Broadcasts the contents of sbuf over comm_ with MPI_Bcast, using rank 0
// as the root. sbuf.size() elements of mpi_type<T>() are transferred.
// No resize of sbuf is visible here, so presumably every rank must
// already hold a vector of the same length -- confirm with callers.
297 template<
typename T>
void
298 broadcast(std::vector<T>& sbuf)
const {
299 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
// Broadcasts the contents of sbuf over comm_ with MPI_Bcast, using rank
// src as the root. sbuf.size() elements of mpi_type<T>() are transferred.
// No resize of sbuf is visible here, so presumably every rank must
// already hold a vector of the same length -- confirm with callers.
301 template<
typename T>
void
302 broadcast_from(std::vector<T>& sbuf,
int src)
const {
303 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), src, comm_);
// Broadcasts a fixed-size std::array<T,N> over comm_ with MPI_Bcast,
// using rank 0 as the root; sbuf.size() elements of mpi_type<T>() are
// transferred.
306 template<
typename T, std::
size_t N>
void
307 broadcast(std::array<T,N>& sbuf)
const {
308 MPI_Bcast(sbuf.data(), sbuf.size(), mpi_type<T>(), 0, comm_);
// Broadcasts a single value over comm_ with MPI_Bcast, using rank 0 as
// the root; exactly one element of mpi_type<T>() is transferred to/from
// the object referenced by data.
311 template<
typename T>
void
312 broadcast(T& data)
const {
313 MPI_Bcast(&data, 1, mpi_type<T>(), 0, comm_);
// Broadcasts a single value over comm_ with MPI_Bcast, using rank src as
// the root; exactly one element of mpi_type<T>() is transferred to/from
// the object referenced by data.
315 template<
typename T>
void
316 broadcast_from(T& data,
int src)
const {
317 MPI_Bcast(&data, 1, mpi_type<T>(), src, comm_);
// Broadcasts ssize elements of mpi_type<T>() starting at sbuf over comm_
// with MPI_Bcast, using rank 0 as the root.
// NOTE(review): ssize is a std::size_t but is handed to MPI_Bcast as its
// count argument; presumably callers keep it within int range -- confirm.
319 template<
typename T>
void
320 broadcast(T* sbuf, std::size_t ssize)
const {
321 MPI_Bcast(sbuf, ssize, mpi_type<T>(), 0, comm_);
// Broadcasts ssize elements of mpi_type<T>() starting at sbuf over comm_
// with MPI_Bcast, using rank src as the root.
// NOTE(review): ssize is a std::size_t but is handed to MPI_Bcast as its
// count argument; presumably callers keep it within int range -- confirm.
323 template<
typename T>
void
324 broadcast_from(T* sbuf, std::size_t ssize,
int src)
const {
325 MPI_Bcast(sbuf, ssize, mpi_type<T>(), src, comm_);
// In-place gather on comm_: MPI_IN_PLACE is passed as the send buffer
// (with a zero count and MPI_DATATYPE_NULL), and buf / rsize /
// mpi_type<T>() describe the receive side.
// NOTE(review): the callee name is on a line not visible in this excerpt
// -- presumably MPI_Allgather, in which case rsize would be the element
// count contributed by each rank; confirm against the full file.
329 void all_gather(T* buf, std::size_t rsize)
const {
331 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
332 buf, rsize, mpi_type<T>(), comm_);
// In-place variable-count gather on comm_: MPI_IN_PLACE is passed as the
// send buffer (with a zero count and MPI_DATATYPE_NULL); buf is the
// receive buffer, described by the rcnts and displs arrays and by
// mpi_type<T>().
// NOTE(review): the callee name is on a line not visible in this excerpt
// -- presumably MPI_Allgatherv, in which case rcnts/displs hold one entry
// per rank; confirm against the full file.
336 void all_gather_v(T* buf,
const int* rcnts,
const int* displs)
const {
338 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, buf, rcnts, displs,
339 mpi_type<T>(), comm_);
360 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
361 dest, tag, comm_, req.req_.get());
// Starts a send of the whole vector sbuf (sbuf.size() elements of
// mpi_type<T>()) to rank dest with the given tag on comm_, via MPI_Isend.
// The request handle is written to *req.
// The const_cast only strips constness from the buffer pointer handed to
// MPI_Isend -- presumably for MPI headers that declare the send buffer
// as non-const.
378 void isend(
const std::vector<T>& sbuf,
int dest,
int tag,
379 MPI_Request* req)
const {
381 MPI_Isend(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(),
382 dest, tag, comm_, req);
// Starts a send of ssize elements of mpi_type<T>() beginning at sbuf to
// rank dest with the given tag on comm_, via MPI_Isend. The request
// handle is written to *req.
// The const_cast only strips constness from the buffer pointer handed to
// MPI_Isend -- presumably for MPI headers that declare the send buffer
// as non-const.
386 void isend(
const T* sbuf, std::size_t ssize,
int dest,
387 int tag, MPI_Request* req)
const {
389 MPI_Isend(
const_cast<T*
>(sbuf), ssize, mpi_type<T>(),
390 dest, tag, comm_, req);
// Starts a send of the single object buf (one element of mpi_type<T>())
// to rank dest with the given tag on comm_, via MPI_Isend. The request
// handle is written to *req.
// The address of buf itself is passed to MPI_Isend, so presumably buf
// must stay alive until the request completes -- confirm with callers.
394 void isend(
const T& buf,
int dest,
int tag, MPI_Request* req)
const {
396 MPI_Isend(
const_cast<T*
>(&buf), 1, mpi_type<T>(),
397 dest, tag, comm_, req);
// Sends the whole vector sbuf (sbuf.size() elements of mpi_type<T>()) to
// rank dest with the given tag on comm_, via MPI_Send.
// The const_cast only strips constness from the buffer pointer handed to
// MPI_Send.
414 void send(
const std::vector<T>& sbuf,
int dest,
int tag)
const {
416 MPI_Send(
const_cast<T*
>(sbuf.data()), sbuf.size(), mpi_type<T>(), dest, tag, comm_);
// Receives a message whose length is not known in advance from rank src
// with the given tag on comm_.
// Steps visible here: MPI_Probe for the message, MPI_Get_count to obtain
// its length in elements of mpi_type<T>(), allocate a vector of that
// length, then MPI_Recv into it (status ignored).
// Declared to return std::vector<T>; the return statement is not visible
// in this excerpt -- presumably rbuf.
431 template<
typename T> std::vector<T>
recv(
int src,
int tag)
const {
433 MPI_Probe(src, tag, comm_, &stat);
435 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
437 std::vector<T> rbuf(msgsize);
438 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), src, tag,
439 comm_, MPI_STATUS_IGNORE);
// Receives a message with the given tag from whichever rank sends one
// (MPI_ANY_SOURCE) on comm_.
// MPI_Probe identifies the pending message, MPI_Get_count gives its
// length in elements of mpi_type<T>(), and MPI_Recv is then issued
// against the specific source reported by the probe (stat.MPI_SOURCE).
// Returns {source rank, received data}; the buffer is moved into the
// returned pair.
444 std::pair<int,std::vector<T>> recv_any_src(
int tag)
const {
446 MPI_Probe(MPI_ANY_SOURCE, tag, comm_, &stat);
448 MPI_Get_count(&stat, mpi_type<T>(), &msgsize);
449 std::vector<T> rbuf(msgsize);
450 MPI_Recv(rbuf.data(), msgsize, mpi_type<T>(), stat.MPI_SOURCE,
451 tag, comm_, MPI_STATUS_IGNORE);
452 return {stat.MPI_SOURCE, std::move(rbuf)};
// Receives exactly one element of mpi_type<T>() from rank src with the
// given tag on comm_ via MPI_Recv (status ignored), storing it in t.
// Declared to return T; the declaration of t and the return statement
// are not visible in this excerpt -- presumably t is returned.
455 template<
typename T>
T recv_one(
int src,
int tag)
const {
457 MPI_Recv(&t, 1, mpi_type<T>(), src, tag, comm_, MPI_STATUS_IGNORE);
// Posts a receive of rsize elements of mpi_type<T>() from rank src with
// the given tag on comm_, via MPI_Irecv. The request handle is written
// to *req.
// NOTE(review): rbuf is declared const T* but is the destination buffer
// of MPI_Irecv; the const_cast removes the constness so MPI can write
// through it. Callers must therefore pass storage that is actually
// writable.
462 void irecv(
const T* rbuf, std::size_t rsize,
int src,
463 int tag, MPI_Request* req)
const {
465 MPI_Irecv(
const_cast<T*
>(rbuf), rsize, mpi_type<T>(),
466 src, tag, comm_, req);
484 MPI_Allreduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, comm_);
// Reduces the single value t across comm_ with operation op, using rank
// 0 as the root of MPI_Reduce.
// Two call forms are visible: one passes MPI_IN_PLACE as the send buffer
// and &t as the receive buffer; the other passes &t for both.
// NOTE(review): the condition selecting between them is on a line not
// visible in this excerpt -- presumably the in-place form is used on the
// root rank; confirm. The return statement is likewise not visible.
502 template<
typename T>
T reduce(T t, MPI_Op op)
const {
504 MPI_Reduce(MPI_IN_PLACE, &t, 1, mpi_type<T>(), op, 0, comm_);
505 else MPI_Reduce(&t, &t, 1, mpi_type<T>(), op, 0, comm_);
// In-place all-reduce of the ssize elements of mpi_type<T>() starting at
// t, with operation op over comm_: MPI_IN_PLACE is passed as the send
// buffer so t serves as both input and output of MPI_Allreduce.
524 template<
typename T>
void all_reduce(T* t,
int ssize, MPI_Op op)
const {
525 MPI_Allreduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, comm_);
528 template<
typename T>
void all_reduce(std::vector<T>& t, MPI_Op op)
const {
// Reduces the ssize elements of mpi_type<T>() starting at t across comm_
// with operation op, using rank 0 as the root of MPI_Reduce.
// Two call forms are visible: one passes MPI_IN_PLACE as the send buffer
// and t as the receive buffer; the other passes t for both.
// NOTE(review): the condition selecting between them is on a line not
// visible in this excerpt -- presumably the in-place form is used on the
// root rank; confirm.
547 template<
typename T>
void reduce(T* t,
int ssize, MPI_Op op)
const {
549 MPI_Reduce(MPI_IN_PLACE, t, ssize, mpi_type<T>(), op, 0, comm_);
550 else MPI_Reduce(t, t, ssize, mpi_type<T>(), op, 0, comm_);
// Exchange between sbuf and rbuf on comm_ in which the same count scnt
// and the same datatype mpi_type<T>() describe both the send and the
// receive side.
// NOTE(review): the callee name is on a line not visible in this excerpt
// -- presumably MPI_Alltoall, in which case scnt is the number of
// elements exchanged with each rank; confirm against the full file.
554 void all_to_all(
const T* sbuf,
int scnt, T* rbuf)
const {
556 (sbuf, scnt, mpi_type<T>(), rbuf, scnt, mpi_type<T>(), comm_);
// Variable-count exchange on comm_ that allocates and returns the
// receive buffer.
// sbuf/scnts/sdispls describe the send side and rcnts/rdispls the
// receive side, all in elements of mpi_type<T>(). A std::vector<T,A> of
// rsize elements is created and its data() is used as receive buffer.
// NOTE(review): the body of the loop over ranks (which presumably
// accumulates rcnts[p] into rsize), the callee name (presumably
// MPI_Alltoallv) and the return statement are on lines not visible in
// this excerpt; confirm against the full file.
559 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
560 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
561 int* rcnts,
int* rdispls)
const {
562 std::size_t rsize = 0;
563 for (
int p=0; p<
size(); p++)
565 std::vector<T,A> rbuf(rsize);
567 (sbuf, scnts, sdispls, mpi_type<T>(),
568 rbuf.data(), rcnts, rdispls, mpi_type<T>(), comm_);
// Variable-count exchange on comm_ into a caller-provided receive
// buffer. sbuf/scnts/sdispls describe the send side and
// rbuf/rcnts/rdispls the receive side, all in elements of mpi_type<T>().
// No allocation is visible here, so presumably rbuf must already be
// large enough for everything described by rcnts/rdispls -- confirm.
// NOTE(review): the callee name is on a line not visible in this excerpt
// -- presumably MPI_Alltoallv.
572 template<
typename T>
void
573 all_to_allv(
const T* sbuf,
int* scnts,
int* sdispls,
574 T* rbuf,
int* rcnts,
int* rdispls)
const {
576 (sbuf, scnts, sdispls, mpi_type<T>(),
577 rbuf, rcnts, rdispls, mpi_type<T>(), comm_);
597 template<
typename T,
typename A=std::allocator<T>>
void
598 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
599 std::vector<T*>& pbuf)
const {
616 template<
typename T,
typename A=std::allocator<T>> std::vector<T,A>
618 std::vector<T,A> rbuf;
619 std::vector<T*> pbuf;
// Personalized all-to-all exchange with an explicit MPI datatype.
//
// sbuf  : one vector per destination rank (asserted to have size()
//         entries); it is emptied (swapped with an empty vector) once the
//         data has been handed to MPI.
// rbuf  : resized to the total number of received elements and filled
//         with the received data.
// pbuf  : pbuf[p] is set to the position inside rbuf where the data from
//         rank p starts.
// Ttype : MPI datatype used for every element transfer.
//
// Two communication strategies are visible below: point-to-point
// MPI_Irecv plus matching sends completed by one MPI_Waitall, and a
// single MPI_Alltoallv on a packed send buffer.
// NOTE(review): several lines are not visible in this excerpt, among
// them the definition of P, the step that fills rsizes, the resize of
// pbuf, the update of displ, and the else/closing lines. The comments
// below describe only what is shown and hedge the rest.
641 template<
typename T,
typename A=std::allocator<T>>
void
642 all_to_all_v(std::vector<std::vector<T>>& sbuf, std::vector<T,A>& rbuf,
643 std::vector<T*>& pbuf,
const MPI_Datatype Ttype)
const {
644 assert(sbuf.size() == std::size_t(
size()));
// A single allocation of 4*P ints is carved into four P-length arrays:
// send counts, receive counts, send displacements, receive displacements.
646 std::unique_ptr<int[]> iwork(
new int[4*P]);
647 auto ssizes = iwork.get();
648 auto rsizes = ssizes + P;
649 auto sdispl = ssizes + 2*P;
650 auto rdispl = ssizes + 3*P;
// Record the per-destination send counts. The comparison against
// numeric_limits<int>::max() and the error message guard against a
// count that does not fit in a 32-bit int (left-hand side of the
// comparison is not visible here -- presumably sbuf[p].size()).
651 for (
int p=0; p<P; p++) {
653 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
654 std::cerr <<
"# ERROR: 32bit integer overflow in all_to_all_v!!"
658 ssizes[p] = sbuf[p].size();
// Total number of elements this rank sends and receives, accumulated in
// std::size_t so the totals themselves cannot overflow int.
662 std::size_t totssize = std::accumulate(ssizes, ssizes+P, std::size_t(0)),
663 totrsize = std::accumulate(rsizes, rsizes+P, std::size_t(0));
// Presumably: if either total exceeds int range (int displacements could
// not address the buffers), use the point-to-point path that follows --
// confirm against the full file.
665 static_cast<std::size_t
>(std::numeric_limits<int>::max()) ||
667 static_cast<std::size_t
>(std::numeric_limits<int>::max())) {
673 rbuf.resize(totrsize);
// 2*P requests: slots [0,P) for the receives, [P,2P) for the sends.
674 std::unique_ptr<MPI_Request[]> reqs(
new MPI_Request[2*P]);
675 std::size_t displ = 0;
// Post one receive per source rank, straight into rbuf at offset displ.
677 for (
int p=0; p<P; p++) {
678 pbuf[p] = rbuf.data() + displ;
679 MPI_Irecv(pbuf[p], rsizes[p], Ttype, p, 0, comm_, reqs.get()+p);
// One send per destination rank, directly from sbuf[p] (callee name not
// visible here -- presumably MPI_Isend).
682 for (
int p=0; p<P; p++)
684 (sbuf[p].data(), ssizes[p], Ttype, p, 0, comm_, reqs.get()+P+p);
685 MPI_Waitall(2*P, reqs.get(), MPI_STATUSES_IGNORE);
// Swap with an empty temporary so the storage held by sbuf is released.
686 std::vector<std::vector<T>>().swap(sbuf);
// Packed path: gather all outgoing data into one contiguous buffer of
// totssize elements and exchange it with a single MPI_Alltoallv.
688 std::unique_ptr<T[]> sendbuf_(
new T[totssize]);
689 auto sendbuf = sendbuf_.get();
// Displacements are running sums of the counts, starting at 0.
690 sdispl[0] = rdispl[0] = 0;
691 for (
int p=1; p<P; p++) {
692 sdispl[p] = sdispl[p-1] + ssizes[p-1];
693 rdispl[p] = rdispl[p-1] + rsizes[p-1];
// Copy each per-destination vector to its slot in the packed buffer.
695 for (
int p=0; p<P; p++)
696 std::copy(sbuf[p].begin(), sbuf[p].end(), sendbuf+sdispl[p]);
// sbuf is no longer needed once packed: release its storage before the
// receive buffer is grown.
697 std::vector<std::vector<T>>().swap(sbuf);
698 rbuf.resize(totrsize);
699 MPI_Alltoallv(sendbuf, ssizes, sdispl, Ttype,
700 rbuf.data(), rsizes, rdispl, Ttype, comm_);
// Expose where the block received from each rank starts inside rbuf.
702 for (
int p=0; p<P; p++)
703 pbuf[p] = rbuf.data() + rdispl[p];
728 assert(P0 + P <=
size());
730 std::vector<int> sub_ranks(P);
731 for (
int i=0; i<P; i++)
732 sub_ranks[i] = P0 + i*stride;
733 MPI_Group group, sub_group;
734 MPI_Comm_group(comm_, &group);
735 MPI_Group_incl(group, P, sub_ranks.data(), &sub_group);
736 MPI_Comm_create(comm_, sub_group, &sub_comm.comm_);
737 MPI_Group_free(&group);
738 MPI_Group_free(&sub_group);
755 MPI_Group group, sub_group;
756 MPI_Comm_group(comm_, &group);
757 MPI_Group_incl(group, 1, &p, &sub_group);
758 MPI_Comm_create(comm_, sub_group, &c0.comm_);
759 MPI_Group_free(&group);
760 MPI_Group_free(&sub_group);
768 MPI_Pcontrol(1, name.c_str());
774 MPI_Pcontrol(-1, name.c_str());
// Queries MPI_Initialized and returns the resulting flag converted to
// bool. Static: does not use any communicator state.
// (The declaration of flag is on a line not visible in this excerpt.)
777 static bool initialized() {
779 MPI_Initialized(&flag);
780 return static_cast<bool>(flag);
// The wrapped communicator handle; defaults to MPI_COMM_WORLD.
784 MPI_Comm comm_ = MPI_COMM_WORLD;
// Makes comm_ refer to a duplicate of c: a null communicator is stored
// as-is, anything else is duplicated with MPI_Comm_dup into comm_.
// NOTE(review): whatever handle comm_ held before is overwritten here
// without being freed in the visible lines; presumably callers only use
// this when comm_ holds nothing that needs MPI_Comm_free -- confirm.
786 void duplicate(MPI_Comm c) {
787 if (c == MPI_COMM_NULL) comm_ = c;
788 else MPI_Comm_dup(c, &comm_);
801 assert(c != MPI_COMM_NULL);
803 MPI_Comm_rank(c, &rank);
815 assert(c != MPI_COMM_NULL);
817 MPI_Comm_size(c, &nprocs);
823 #endif // STRUMPACK_MPI_WRAPPER_HPP