Skip to content

Commit

Permalink
Rework wait: own files, partial struct specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Nov 5, 2024
1 parent 4bc3385 commit 69f87e0
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 61 deletions.
2 changes: 2 additions & 0 deletions src/KokkosComm/KokkosComm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
#include "mpi/isend.hpp"
#include "mpi/recv.hpp"
#include "mpi/reduce.hpp"
#include "mpi/impl/wait.hpp"
#else
#error at least one transport must be defined
#endif

#include "concepts.hpp"
#include "point_to_point.hpp"
#include "collective.hpp"
#include "wait.hpp"

#include <Kokkos_Core.hpp>

Expand Down
9 changes: 9 additions & 0 deletions src/KokkosComm/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
CommunicationSpace CommSpace = DefaultCommunicationSpace>
struct Barrier;

// Completion of a single request. Declaration only: the behavior comes from
// partial specializations on CommSpace (e.g. the Mpi one in mpi/impl/wait.hpp).
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
struct Wait;

// Completion of every request in a collection. Declaration only; specialized
// per CommSpace like Wait.
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
struct WaitAll;

// Completion of one request out of a collection. Declaration only; specialized
// per CommSpace like Wait.
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
struct WaitAny;

} // namespace Impl

} // namespace KokkosComm
79 changes: 79 additions & 0 deletions src/KokkosComm/mpi/impl/wait.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <KokkosComm/mpi/req.hpp>

namespace KokkosComm::Impl {

/* Completes a single MPI request.
 *
 * All of the work happens in the constructor: the execution space is fenced,
 * the request is waited on with MPI_Wait, and the callbacks stored in the
 * request's record are run and then removed.
 */
template <KokkosExecutionSpace ExecSpace>
struct Wait<ExecSpace, Mpi> {
  Wait(const ExecSpace &space, Req<Mpi> req) {
    // all work submitted to the execution space must be done before the communication is completed
    space.fence();
    MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);

    // run the deferred post-wait callbacks, then empty the list
    auto &callbacks = req.record_->postWaits_;
    for (auto &callback : callbacks) {
      callback();
    }
    callbacks.clear();
  }
};

/* Completes every MPI request in `reqs`.
 *
 * The execution space is fenced once, then each request is waited on in
 * container order; after each MPI_Wait the callbacks stored in that request's
 * record are run and then removed.
 */
template <KokkosExecutionSpace ExecSpace>
struct WaitAll<ExecSpace, Mpi> {
  WaitAll(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
    // all work submitted to the execution space must be done before any communication is completed
    space.fence();

    for (Req<Mpi> &pending : reqs) {
      MPI_Wait(&pending.mpi_request(), MPI_STATUS_IGNORE);

      // run this request's deferred post-wait callbacks, then empty the list
      auto &callbacks = pending.record_->postWaits_;
      for (auto &callback : callbacks) {
        callback();
      }
      callbacks.clear();
    }
  }
};

/* Blocks until one request in `reqs` is complete and returns its index.
 *
 * Returns -1 immediately, without fencing, when `reqs` is empty. Otherwise the
 * execution space is fenced once and the requests are polled round-robin with
 * MPI_Test until one reports completion; that request's post-wait callbacks
 * are run and removed before its index is returned.
 *
 * NOTE(review): MPI_Test reports null/inactive requests as complete. If
 * mpi_request() refers to the stored handle, a request finished by an earlier
 * call would be reported again - confirm callers drop completed requests.
 */
template <KokkosExecutionSpace ExecSpace>
struct WaitAny<ExecSpace, Mpi> {
  static int execute(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
    if (reqs.empty()) {
      return -1;
    }

    // ensure that the execution space has completed all work before completing the communication
    space.fence();
    while (true) {  // keep polling until something is done
      for (size_t i = 0; i < reqs.size(); ++i) {
        Req<Mpi> &req = reqs[i];
        // start from a defined value so a failing MPI_Test that leaves the flag
        // untouched is treated as "not complete" instead of reading garbage
        int completed = 0;
        MPI_Test(&(req.mpi_request()), &completed, MPI_STATUS_IGNORE);
        if (completed) {
          for (auto &f : req.record_->postWaits_) {
            f();
          }
          req.record_->postWaits_.clear();
          // the public result type is int; make the size_t -> int conversion explicit
          return static_cast<int>(i);
        }
      }
    }
  }
};

} // namespace KokkosComm::Impl
68 changes: 7 additions & 61 deletions src/KokkosComm/mpi/req.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,68 +64,14 @@ class Req<Mpi> {
private:
std::shared_ptr<Record> record_;

template <KokkosExecutionSpace ExecSpace>
friend void wait(const ExecSpace &space, Req<Mpi> req);
friend void wait(Req<Mpi> req);
template <KokkosExecutionSpace ExecSpace>
friend void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
friend void wait_all(std::vector<Req<Mpi>> &reqs);
template <KokkosExecutionSpace ExecSpace>
friend int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
friend int wait_any(std::vector<Req<Mpi>> &reqs);
};

/* Completes a single MPI request: fences `space`, waits on the request with
 * MPI_Wait, then runs and removes the callbacks stored in the request's record.
 */
template <KokkosExecutionSpace ExecSpace>
void wait(const ExecSpace &space, Req<Mpi> req) {
  /* Semantically this only guarantees that `space` is waiting for request to complete. For the MPI host API, we have
   * no choice but to fence the space before waiting on the host.*/
  space.fence();
  MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);

  // run the deferred post-wait callbacks, then empty the list
  auto &callbacks = req.record_->postWaits_;
  for (auto &callback : callbacks) {
    callback();
  }
  callbacks.clear();
}

// Convenience overload: waits on `req` using a default-constructed Kokkos::DefaultExecutionSpace.
inline void wait(Req<Mpi> req) { wait(Kokkos::DefaultExecutionSpace(), req); }

template <KokkosExecutionSpace ExecSpace>
void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
space.fence();
for (Req<Mpi> &req : reqs) {
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
for (auto &f : req.record_->postWaits_) {
f();
}
req.record_->postWaits_.clear();
}
}
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
friend struct KokkosComm::Impl::Wait;

// Convenience overload: waits on all of `reqs` using a default-constructed Kokkos::DefaultExecutionSpace.
inline void wait_all(std::vector<Req<Mpi>> &reqs) { wait_all(Kokkos::DefaultExecutionSpace(), reqs); }
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
friend struct KokkosComm::Impl::WaitAll;

/* Blocks until one request in `reqs` is complete and returns its index.
 *
 * Returns -1 immediately, without fencing, when `reqs` is empty. Otherwise
 * `space` is fenced once and the requests are polled round-robin with MPI_Test
 * until one reports completion; that request's stored callbacks are run and
 * removed before its index is returned.
 */
template <KokkosExecutionSpace ExecSpace>
int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
  if (reqs.empty()) {
    return -1;
  }

  space.fence();
  while (true) {  // keep polling until something is done
    for (size_t i = 0; i < reqs.size(); ++i) {
      Req<Mpi> &req = reqs[i];
      // start from a defined value so a failing MPI_Test that leaves the flag
      // untouched is treated as "not complete" instead of reading garbage
      int completed = 0;
      MPI_Test(&(req.mpi_request()), &completed, MPI_STATUS_IGNORE);
      if (completed) {
        for (auto &f : req.record_->postWaits_) {
          f();
        }
        req.record_->postWaits_.clear();
        // the result type is int; make the size_t -> int conversion explicit
        return static_cast<int>(i);
      }
    }
  }
}

// Convenience overload: forwards to the two-argument wait_any with a default-constructed
// Kokkos::DefaultExecutionSpace and returns its result.
inline int wait_any(std::vector<Req<Mpi>> &reqs) { return wait_any(Kokkos::DefaultExecutionSpace(), reqs); }
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
friend struct KokkosComm::Impl::WaitAny;
};

} // namespace KokkosComm
55 changes: 55 additions & 0 deletions src/KokkosComm/wait.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <Kokkos_Core.hpp>

#include "fwd.hpp"
#include "concepts.hpp"

namespace KokkosComm {

// FIXME: reverse order of these template params for automatic deduction
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
void wait(const ExecSpace &space, Req<CommSpace> req) {
Impl::Wait<ExecSpace, CommSpace>(space, req);
}

// Completes every request in `reqs` on behalf of `space` by constructing the
// Impl::WaitAll specialization selected by ExecSpace and CommSpace.
//
// `reqs` holds requests of the same communication space that is dispatched to,
// so CommSpace can be deduced from the argument and is not tied to one
// transport. Explicitly specifying both template arguments keeps working.
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
void wait_all(const ExecSpace &space, std::vector<Req<CommSpace>> &reqs) {
  // the constructor of the temporary performs the wait
  Impl::WaitAll<ExecSpace, CommSpace>(space, reqs);
}

// Waits for one request in `reqs` on behalf of `space` and returns whatever
// Impl::WaitAny<ExecSpace, CommSpace>::execute returns for it.
//
// `reqs` holds requests of the same communication space that is dispatched to,
// so CommSpace can be deduced from the argument and is not tied to one
// transport. Explicitly specifying both template arguments keeps working.
template <KokkosExecutionSpace ExecSpace, CommunicationSpace CommSpace>
int wait_any(const ExecSpace &space, std::vector<Req<CommSpace>> &reqs) {
  return Impl::WaitAny<ExecSpace, CommSpace>::execute(space, reqs);
}

// Convenience overload: completes `req` using a default-constructed
// Kokkos::DefaultExecutionSpace.
template <CommunicationSpace CommSpace>
inline void wait(Req<CommSpace> req) {
  using DefaultSpace = Kokkos::DefaultExecutionSpace;
  wait<DefaultSpace, CommSpace>(DefaultSpace{}, req);
}
// Convenience overload: completes every request in `reqs` using a
// default-constructed Kokkos::DefaultExecutionSpace.
template <CommunicationSpace CommSpace>
inline void wait_all(std::vector<Req<CommSpace>> &reqs) {
  using DefaultSpace = Kokkos::DefaultExecutionSpace;
  wait_all<DefaultSpace, CommSpace>(DefaultSpace{}, reqs);
}
// Convenience overload: forwards to the two-argument wait_any with a
// default-constructed Kokkos::DefaultExecutionSpace and returns its result.
template <CommunicationSpace CommSpace>
inline int wait_any(std::vector<Req<CommSpace>> &reqs) {
  using DefaultSpace = Kokkos::DefaultExecutionSpace;
  return wait_any<DefaultSpace, CommSpace>(DefaultSpace{}, reqs);
}

} // namespace KokkosComm

0 comments on commit 69f87e0

Please sign in to comment.