From 200ebaea119aac51e22f1f66467fd35c0e3b372c Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 17 May 2024 09:32:31 -0600 Subject: [PATCH] mpi: allgather: low-level API, space handling --- src/impl/KokkosComm_allgather.hpp | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/impl/KokkosComm_allgather.hpp b/src/impl/KokkosComm_allgather.hpp index f09b6251..75e41f61 100644 --- a/src/impl/KokkosComm_allgather.hpp +++ b/src/impl/KokkosComm_allgather.hpp @@ -26,8 +26,9 @@ #include "KokkosComm_types.hpp" namespace KokkosComm::Impl { -template -void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) { + +template +void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) { Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather"); using ST = KokkosComm::Traits; @@ -38,12 +39,30 @@ void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, M static_assert(ST::rank() <= 1, "allgather for SendView::rank > 1 not supported"); static_assert(RT::rank() <= 1, "allgather for RecvView::rank > 1 not supported"); - if (KokkosComm::PackTraits::needs_pack(sv) || KokkosComm::PackTraits::needs_pack(rv)) { + if (!ST::is_contiguous(sv)) { + throw std::runtime_error("low-level allgather requires contiguous send view"); + } + if (!RT::is_contiguous(rv)) { + throw std::runtime_error("low-level allgather requires contiguous recv view"); + } + const int count = ST::span(sv); // all ranks send/recv same count + MPI_Allgather(ST::data_handle(sv), count, mpi_type_v, RT::data_handle(rv), count, mpi_type_v, + comm); + + Kokkos::Tools::popRegion(); +} + +template +void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) { + Kokkos::Tools::pushRegion("KokkosComm::Impl::allgather"); + using ST = KokkosComm::Traits; + using RT = KokkosComm::Traits; + + if (ST::needs_pack(sv) || RT::needs_pack(rv)) { throw std::runtime_error("allgather for non-contiguous views not implemented"); } else { - const int count = ST::span(sv); // all ranks send/recv same count - MPI_Allgather(ST::data_handle(sv), count, mpi_type_v, RT::data_handle(rv), count, - mpi_type_v, comm); + space.fence(); // work in space may have been used to produce send view data + allgather(sv, rv, comm); } Kokkos::Tools::popRegion();