Skip to content

Commit

Permalink
Stream-ordered wait{_any,_all}, fix wait_any implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Nov 5, 2024
1 parent b153676 commit 4bc3385
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 10 deletions.
54 changes: 44 additions & 10 deletions src/KokkosComm/mpi/req.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,68 @@ class Req<Mpi> {
private:
std::shared_ptr<Record> record_;

template <KokkosExecutionSpace ExecSpace>
friend void wait(const ExecSpace &space, Req<Mpi> req);
friend void wait(Req<Mpi> req);
template <KokkosExecutionSpace ExecSpace>
friend void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
friend void wait_all(std::vector<Req<Mpi>> &reqs);
template <KokkosExecutionSpace ExecSpace>
friend int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
friend int wait_any(std::vector<Req<Mpi>> &reqs);
};

inline void wait(Req<Mpi> req) {
template <KokkosExecutionSpace ExecSpace>
void wait(const ExecSpace &space, Req<Mpi> req) {
/* Semantically this only guarantees that `space` is waiting for request to complete. For the MPI host API, we have no
* choice but to fence the space before waiting on the host.*/
space.fence();
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
for (auto &f : req.record_->postWaits_) {
f();
}
req.record_->postWaits_.clear();
}

inline void wait_all(std::vector<Req<Mpi>> &reqs) {
inline void wait(Req<Mpi> req) { wait(Kokkos::DefaultExecutionSpace(), req); }

template <KokkosExecutionSpace ExecSpace>
void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
space.fence();
for (Req<Mpi> &req : reqs) {
wait(req);
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
for (auto &f : req.record_->postWaits_) {
f();
}
req.record_->postWaits_.clear();
}
}

inline int wait_any(std::vector<Req<Mpi>> &reqs) {
for (size_t i = 0; i < reqs.size(); ++i) {
int completed;
MPI_Test(&(reqs[i].mpi_request()), &completed, MPI_STATUS_IGNORE);
if (completed) {
return true;
inline void wait_all(std::vector<Req<Mpi>> &reqs) { wait_all(Kokkos::DefaultExecutionSpace(), reqs); }

template <KokkosExecutionSpace ExecSpace>
int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
if (reqs.empty()) {
return -1;
}

space.fence();
while (true) { // wait until something is done
for (size_t i = 0; i < reqs.size(); ++i) {
int completed;
Req<Mpi> &req = reqs[i];
MPI_Test(&(req.mpi_request()), &completed, MPI_STATUS_IGNORE);
if (completed) {
for (auto &f : req.record_->postWaits_) {
f();
}
req.record_->postWaits_.clear();
return i;
}
}
}
return false;
}

inline int wait_any(std::vector<Req<Mpi>> &reqs) { return wait_any(Kokkos::DefaultExecutionSpace(), reqs); }

} // namespace KokkosComm
1 change: 1 addition & 0 deletions unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ target_sources(
mpi/test_alltoall.cpp
mpi/test_reduce.cpp
mpi/test_allgather.cpp
mpi/test_waitany.cpp
)
target_link_libraries(
test-main
Expand Down
105 changes: 105 additions & 0 deletions unit_tests/mpi/test_waitany.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>
#include <type_traits>
#include <algorithm> // iota
#include <random>

#include "KokkosComm/KokkosComm.hpp"

namespace {

using namespace KokkosComm::mpi;

template <typename T>
class MpiWaitAny : public testing::Test {
public:
using Scalar = T;
};

using ScalarTypes = ::testing::Types<int, double, Kokkos::complex<float>>;
TYPED_TEST_SUITE(MpiWaitAny, ScalarTypes);

template <KokkosComm::KokkosExecutionSpace ExecSpace, typename Scalar>
void wait_any() {
using TestView = Kokkos::View<Scalar *>;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (size < 2) {
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
}

constexpr size_t numMsg = 128;
ExecSpace space;
std::vector<KokkosComm::Req<>> reqs;
std::vector<TestView> views;

for (size_t i = 0; i < numMsg; ++i) {
views.push_back(TestView(std::to_string(i), i));
}

constexpr unsigned int SEED = 31337;
std::random_device rd;
std::mt19937 g(SEED);

// random send/recv order
std::vector<size_t> order(numMsg);
std::iota(order.begin(), order.end(), size_t(0));
std::shuffle(order.begin(), order.end(), g);

KokkosComm::Handle<ExecSpace, KokkosComm::Mpi> h(space, MPI_COMM_WORLD);

if (0 == rank) {
constexpr int dst = 1;

for (size_t i : order) {
reqs.push_back(KokkosComm::send(h, views[i], dst));
}

for (size_t i = 0; i < numMsg; ++i) {
reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs));
}
} else if (1 == rank) {
constexpr int src = 0;

for (size_t i : order) {
reqs.push_back(KokkosComm::recv(h, views[i], src));
}

for (size_t i = 0; i < numMsg; ++i) {
reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs));
}
}
}

// TODO: test call on no requests

TYPED_TEST(MpiWaitAny, default_execution_space) {
wait_any<Kokkos::DefaultExecutionSpace, typename TestFixture::Scalar>();
}

TYPED_TEST(MpiWaitAny, default_host_execution_space) {
if constexpr (std::is_same_v<Kokkos::DefaultHostExecutionSpace, Kokkos::DefaultExecutionSpace>) {
GTEST_SKIP() << "Skipping test: DefaultHostExecSpace = DefaultExecSpace";
} else {
wait_any<Kokkos::DefaultHostExecutionSpace, typename TestFixture::Scalar>();
}
}

} // namespace

0 comments on commit 4bc3385

Please sign in to comment.