Skip to content

Commit

Permalink
mpi: allgather: low-level API, space handling
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed May 17, 2024
1 parent bd4f7dd commit 200ebae
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions src/impl/KokkosComm_allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
#include "KokkosComm_types.hpp"

namespace KokkosComm::Impl {
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) {

template <KokkosView SendView, KokkosView RecvView>
void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather");

using ST = KokkosComm::Traits<SendView>;
Expand All @@ -38,12 +39,30 @@ void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, M
static_assert(ST::rank() <= 1, "allgather for SendView::rank > 1 not supported");
static_assert(RT::rank() <= 1, "allgather for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<SendView>::needs_pack(sv) || KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
if (!ST::is_contiguous(sv)) {
throw std::runtime_error("low-level allgather requires contiguous send view");
}
if (!RT::is_contiguous(rv)) {
throw std::runtime_error("low-level allgather requires contiguous recv view");
}
const int count = ST::span(sv); // all ranks send/recv same count
MPI_Allgather(ST::data_handle(sv), count, mpi_type_v<SendScalar>, RT::data_handle(rv), count, mpi_type_v<RecvScalar>,
comm);

Kokkos::Tools::popRegion();
}

template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather");
using ST = KokkosComm::Traits<SendView>;
using RT = KokkosComm::Traits<RecvView>;

if (ST::needs_pack(sv) || RT::needs_pack(rv)) {
throw std::runtime_error("allgather for non-contiguous views not implemented");
} else {
const int count = ST::span(sv); // all ranks send/recv same count
MPI_Allgather(ST::data_handle(sv), count, mpi_type_v<SendScalar>, RT::data_handle(rv), count,
mpi_type_v<RecvScalar>, comm);
space.fence(); // work in space may have been used to produce send view data
allgather(sv, rv, comm);
}

Kokkos::Tools::popRegion();
Expand Down

0 comments on commit 200ebae

Please sign in to comment.