Skip to content

Commit

Permalink
Merge pull request #55 from cwpearson/mpi/irecv-low
Browse files Browse the repository at this point in the history
mpi: contiguous Irecv wrapper
  • Loading branch information
cwpearson authored May 17, 2024
2 parents 4223662 + 3b5e286 commit aebe2ab
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/KokkosComm_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct Traits {
/*! \brief This can be specialized to do custom behavior for a particular view*/
template <KokkosView View>
struct Traits<View> {
// product of extents is span
static bool is_contiguous(const View &v) { return v.span_is_contiguous(); }

static auto data_handle(const View &v) { return v.data(); }
Expand Down
47 changes: 47 additions & 0 deletions src/impl/KokkosComm_irecv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <memory>

#include <Kokkos_Core.hpp>

#include "KokkosComm_pack_traits.hpp"
#include "KokkosComm_traits.hpp"

// impl
#include "KokkosComm_include_mpi.hpp"

namespace KokkosComm::Impl {

// low-level API
template <KokkosView RecvView>
void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::irecv");

using KCT = KokkosComm::Traits<RecvView>;

if (KCT::is_contiguous(rv)) {
using RecvScalar = typename RecvView::value_type;
MPI_Irecv(KCT::data_handle(rv), KCT::span(rv), mpi_type_v<RecvScalar>, src, tag, comm, &req);
} else {
throw std::runtime_error("Only contiguous irecv viewsupported");
}

Kokkos::Tools::popRegion();
}
} // namespace KokkosComm::Impl
1 change: 1 addition & 0 deletions unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ target_link_libraries(test-mpi MPI::MPI_CXX)
# Kokkos Comm tests
add_executable(test-main test_main.cpp
test_gtest_mpi.cpp
test_isendirecv.cpp
test_isendrecv.cpp
test_sendrecv.cpp
test_barrier.cpp
Expand Down
109 changes: 109 additions & 0 deletions unit_tests/test_isendirecv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>

#include "KokkosComm.hpp"
#include "KokkosComm_irecv.hpp"

#include "view_builder.hpp"

namespace {

template <typename T>
class IsendIrecv : public testing::Test {
public:
using Scalar = T;
};

using ScalarTypes =
::testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
TYPED_TEST_SUITE(IsendIrecv, ScalarTypes);

template <KokkosComm::KokkosView View1D>
void test_1d(const View1D &a) {
static_assert(View1D::rank == 1, "");
using Scalar = typename View1D::non_const_value_type;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (size < 2) {
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
}

if (0 == rank) {
int dst = 1;
Kokkos::parallel_for(
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
KokkosComm::Req req = KokkosComm::isend(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD);
req.wait();
} else if (1 == rank) {
int src = 0;
MPI_Request req;
KokkosComm::Impl::irecv(a, src, 0, MPI_COMM_WORLD, req);
MPI_Wait(&req, MPI_STATUS_IGNORE);
int errs;
Kokkos::parallel_reduce(
a.extent(0), KOKKOS_LAMBDA(const int &i, int &lsum) { lsum += a(i) != Scalar(i); }, errs);
ASSERT_EQ(errs, 0);
}
}

template <KokkosComm::KokkosView View2D>
void test_2d(const View2D &a) {
static_assert(View2D::rank == 2, "");
using Scalar = typename View2D::non_const_value_type;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (size < 2) {
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
}

using Policy = Kokkos::MDRangePolicy<Kokkos::Rank<2>>;
Policy policy({0, 0}, {a.extent(0), a.extent(1)});

if (0 == rank) {
int dst = 1;
Kokkos::parallel_for(
policy, KOKKOS_LAMBDA(int i, int j) { a(i, j) = i * a.extent(0) + j; });
KokkosComm::Req req = KokkosComm::isend(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD);
req.wait();
} else if (1 == rank) {
int src = 0;
MPI_Request req;
KokkosComm::Impl::irecv(a, src, 0, MPI_COMM_WORLD, req);
MPI_Wait(&req, MPI_STATUS_IGNORE);
int errs;
Kokkos::parallel_reduce(
policy, KOKKOS_LAMBDA(int i, int j, int &lsum) { lsum += a(i, j) != Scalar(i * a.extent(0) + j); }, errs);
ASSERT_EQ(errs, 0);
}
}

TYPED_TEST(IsendIrecv, 1D_contig) {
auto a = ViewBuilder<typename TestFixture::Scalar, 1>::view(contig{}, "a", 1013);
test_1d(a);
}

TYPED_TEST(IsendIrecv, 2D_contig) {
auto a = ViewBuilder<typename TestFixture::Scalar, 2>::view(contig{}, "a", 137, 17);
test_2d(a);
}

} // namespace
1 change: 0 additions & 1 deletion unit_tests/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ int main(int argc, char *argv[]) {
if (0 == rank) {
std::cerr << argv[0] << " (KokkosComm " << KOKKOSCOMM_VERSION_MAJOR << "." << KOKKOSCOMM_VERSION_MINOR << "."
<< KOKKOSCOMM_VERSION_PATCH << ")\n";
std::cerr << "size=" << size << "\n";
}

Kokkos::initialize();
Expand Down
46 changes: 46 additions & 0 deletions unit_tests/view_builder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <Kokkos_Core.hpp>

struct contig {};
struct noncontig {};

template <typename T, int RANK>
struct ViewBuilder;

template <typename T>
struct ViewBuilder<T, 1> {
static auto view(noncontig, const std::string &name, int e0) {
// this is C-style layout, i.e. v(0,0) is next to v(0,1)
Kokkos::View<T **, Kokkos::LayoutRight> v(name, e0, 2);
return Kokkos::subview(v, Kokkos::ALL, 1); // take column 1
}

static auto view(contig, const std::string &name, int e0) { return Kokkos::View<T *>(name, e0); }
};

template <typename T>
struct ViewBuilder<T, 2> {
static auto view(noncontig, const std::string &name, int e0, int e1) {
Kokkos::View<T ***, Kokkos::LayoutRight> v(name, e0, e1, 2);
return Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, 1);
}

static auto view(contig, const std::string &name, int e0, int e1) { return Kokkos::View<T **>(name, e0, e1); }
};

0 comments on commit aebe2ab

Please sign in to comment.