diff --git a/src/KokkosComm/mpi/req.hpp b/src/KokkosComm/mpi/req.hpp index 72d4b1cb..2023f915 100644 --- a/src/KokkosComm/mpi/req.hpp +++ b/src/KokkosComm/mpi/req.hpp @@ -64,12 +64,22 @@ class Req { private: std::shared_ptr record_; + template + friend void wait(const ExecSpace &space, Req req); friend void wait(Req req); + template + friend void wait_all(const ExecSpace &space, std::vector> &reqs); friend void wait_all(std::vector> &reqs); + template + friend int wait_any(const ExecSpace &space, std::vector> &reqs); friend int wait_any(std::vector> &reqs); }; -inline void wait(Req req) { +template +void wait(const ExecSpace &space, Req req) { + /* Semantically this only guarantees that `space` is waiting for request to complete. For the MPI host API, we have no + * choice but to fence the space before waiting on the host.*/ + space.fence(); MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE); for (auto &f : req.record_->postWaits_) { f(); @@ -77,21 +87,45 @@ inline void wait(Req req) { req.record_->postWaits_.clear(); } -inline void wait_all(std::vector> &reqs) { +inline void wait(Req req) { wait(Kokkos::DefaultExecutionSpace(), req); } + +template +void wait_all(const ExecSpace &space, std::vector> &reqs) { + space.fence(); for (Req &req : reqs) { - wait(req); + MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE); + for (auto &f : req.record_->postWaits_) { + f(); + } + req.record_->postWaits_.clear(); } } -inline int wait_any(std::vector> &reqs) { - for (size_t i = 0; i < reqs.size(); ++i) { - int completed; - MPI_Test(&(reqs[i].mpi_request()), &completed, MPI_STATUS_IGNORE); - if (completed) { - return true; +inline void wait_all(std::vector> &reqs) { wait_all(Kokkos::DefaultExecutionSpace(), reqs); } + +template +int wait_any(const ExecSpace &space, std::vector> &reqs) { + if (reqs.empty()) { + return -1; + } + + space.fence(); + while (true) { // wait until something is done + for (size_t i = 0; i < reqs.size(); ++i) { + int completed; + Req &req = reqs[i]; + MPI_Test(&(req.mpi_request()), &completed, MPI_STATUS_IGNORE); + if (completed) { + for (auto &f : req.record_->postWaits_) { + f(); + } + req.record_->postWaits_.clear(); + return i; + } } } - return false; } +inline int wait_any(std::vector> &reqs) { return wait_any(Kokkos::DefaultExecutionSpace(), reqs); } + } // namespace KokkosComm \ No newline at end of file diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 1a4ae48c..c70e2a67 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -86,6 +86,7 @@ target_sources( mpi/test_alltoall.cpp mpi/test_reduce.cpp mpi/test_allgather.cpp + mpi/test_waitany.cpp ) target_link_libraries( test-main diff --git a/unit_tests/mpi/test_waitany.cpp b/unit_tests/mpi/test_waitany.cpp new file mode 100644 index 00000000..3bcef020 --- /dev/null +++ b/unit_tests/mpi/test_waitany.cpp @@ -0,0 +1,105 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include +#include +#include // iota +#include + +#include "KokkosComm/KokkosComm.hpp" + +namespace { + +using namespace KokkosComm::mpi; + +template +class MpiWaitAny : public testing::Test { + public: + using Scalar = T; +}; + +using ScalarTypes = ::testing::Types>; +TYPED_TEST_SUITE(MpiWaitAny, ScalarTypes); + +template +void wait_any() { + using TestView = Kokkos::View; + + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + if (size < 2) { + GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)"; + } + + constexpr size_t numMsg = 128; + ExecSpace space; + std::vector> reqs; + std::vector views; + + for (size_t i = 0; i < numMsg; ++i) { + views.push_back(TestView(std::to_string(i), i)); + } + + constexpr unsigned int SEED = 31337; + std::random_device rd; + std::mt19937 g(SEED); + + // random send/recv order + std::vector order(numMsg); + std::iota(order.begin(), order.end(), size_t(0)); + std::shuffle(order.begin(), order.end(), g); + + KokkosComm::Handle h(space, MPI_COMM_WORLD); + + if (0 == rank) { + constexpr int dst = 1; + + for (size_t i : order) { + reqs.push_back(KokkosComm::send(h, views[i], dst)); + } + + for (size_t i = 0; i < numMsg; ++i) { + reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs)); + } + } else if (1 == rank) { + constexpr int src = 0; + + for (size_t i : order) { + reqs.push_back(KokkosComm::recv(h, views[i], src)); + } + + for (size_t i = 0; i < numMsg; ++i) { + reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs)); + } + } +} + +// TODO: test call on no requests + +TYPED_TEST(MpiWaitAny, default_execution_space) { + wait_any(); +} + +TYPED_TEST(MpiWaitAny, default_host_execution_space) { + if constexpr (std::is_same_v) { + GTEST_SKIP() << "Skipping test: DefaultHostExecSpace = DefaultExecSpace"; + } else { + wait_any(); + } +} + +} // namespace