Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpi: contiguous Irecv wrapper #55

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/KokkosComm_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct Traits {
/*! \brief This can be specialized to do custom behavior for a particular view*/
template <KokkosView View>
struct Traits<View> {
// product of extents is span
static bool is_contiguous(const View &v) { return v.span_is_contiguous(); }

static auto data_handle(const View &v) { return v.data(); }
Expand Down
47 changes: 47 additions & 0 deletions src/impl/KokkosComm_irecv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <memory>

#include <Kokkos_Core.hpp>

#include "KokkosComm_pack_traits.hpp"
#include "KokkosComm_traits.hpp"

// impl
#include "KokkosComm_include_mpi.hpp"

namespace KokkosComm::Impl {

// low-level API
template <KokkosView RecvView>
void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) {
cedricchevalier19 marked this conversation as resolved.
Show resolved Hide resolved
Kokkos::Tools::pushRegion("KokkosComm::Impl::irecv");

using KCT = KokkosComm::Traits<RecvView>;

if (KCT::is_contiguous(rv)) {
using RecvScalar = typename RecvView::value_type;
MPI_Irecv(KCT::data_handle(rv), KCT::span(rv), mpi_type_v<RecvScalar>, src, tag, comm, &req);
} else {
throw std::runtime_error("Only contiguous irecv viewsupported");
}

Kokkos::Tools::popRegion();
}
} // namespace KokkosComm::Impl
1 change: 1 addition & 0 deletions unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ target_link_libraries(test-mpi MPI::MPI_CXX)
# Kokkos Comm tests
add_executable(test-main test_main.cpp
test_gtest_mpi.cpp
test_isendirecv.cpp
test_isendrecv.cpp
test_reduce.cpp
test_sendrecv.cpp
Expand Down
109 changes: 109 additions & 0 deletions unit_tests/test_isendirecv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>

#include "KokkosComm.hpp"
#include "KokkosComm_irecv.hpp"

#include "view_builder.hpp"

namespace {

template <typename T>
class IsendIrecv : public testing::Test {
public:
using Scalar = T;
};

using ScalarTypes =
::testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
TYPED_TEST_SUITE(IsendIrecv, ScalarTypes);

template <KokkosComm::KokkosView View1D>
void test_1d(const View1D &a) {
static_assert(View1D::rank == 1, "");
using Scalar = typename View1D::non_const_value_type;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (size < 2) {
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
}

if (0 == rank) {
int dst = 1;
Kokkos::parallel_for(
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
KokkosComm::Req req = KokkosComm::isend(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD);
req.wait();
cedricchevalier19 marked this conversation as resolved.
Show resolved Hide resolved
} else if (1 == rank) {
int src = 0;
MPI_Request req;
KokkosComm::Impl::irecv(a, src, 0, MPI_COMM_WORLD, req);
MPI_Wait(&req, MPI_STATUS_IGNORE);
cwpearson marked this conversation as resolved.
Show resolved Hide resolved
int errs;
Kokkos::parallel_reduce(
a.extent(0), KOKKOS_LAMBDA(const int &i, int &lsum) { lsum += a(i) != Scalar(i); }, errs);
ASSERT_EQ(errs, 0);
}
}

template <KokkosComm::KokkosView View2D>
void test_2d(const View2D &a) {
static_assert(View2D::rank == 2, "");
using Scalar = typename View2D::non_const_value_type;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if (size < 2) {
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
}

using Policy = Kokkos::MDRangePolicy<Kokkos::Rank<2>>;
Policy policy({0, 0}, {a.extent(0), a.extent(1)});

if (0 == rank) {
int dst = 1;
Kokkos::parallel_for(
policy, KOKKOS_LAMBDA(int i, int j) { a(i, j) = i * a.extent(0) + j; });
KokkosComm::Req req = KokkosComm::isend(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD);
req.wait();
} else if (1 == rank) {
int src = 0;
MPI_Request req;
KokkosComm::Impl::irecv(a, src, 0, MPI_COMM_WORLD, req);
MPI_Wait(&req, MPI_STATUS_IGNORE);
int errs;
Kokkos::parallel_reduce(
policy, KOKKOS_LAMBDA(int i, int j, int &lsum) { lsum += a(i, j) != Scalar(i * a.extent(0) + j); }, errs);
ASSERT_EQ(errs, 0);
}
}

TYPED_TEST(IsendIrecv, 1D_contig) {
auto a = ViewBuilder<typename TestFixture::Scalar, 1>::view(contig{}, "a", 1013);
test_1d(a);
}

TYPED_TEST(IsendIrecv, 2D_contig) {
auto a = ViewBuilder<typename TestFixture::Scalar, 2>::view(contig{}, "a", 137, 17);
test_2d(a);
}

} // namespace
1 change: 0 additions & 1 deletion unit_tests/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ int main(int argc, char *argv[]) {
if (0 == rank) {
std::cerr << argv[0] << " (KokkosComm " << KOKKOSCOMM_VERSION_MAJOR << "." << KOKKOSCOMM_VERSION_MINOR << "."
<< KOKKOSCOMM_VERSION_PATCH << ")\n";
std::cerr << "size=" << size << "\n";
}

Kokkos::initialize();
Expand Down
46 changes: 46 additions & 0 deletions unit_tests/view_builder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <Kokkos_Core.hpp>

struct contig {};
struct noncontig {};

template <typename T, int RANK>
struct ViewBuilder;

template <typename T>
struct ViewBuilder<T, 1> {
static auto view(noncontig, const std::string &name, int e0) {
// this is C-style layout, i.e. v(0,0) is next to v(0,1)
Kokkos::View<T **, Kokkos::LayoutRight> v(name, e0, 2);
return Kokkos::subview(v, Kokkos::ALL, 1); // take column 1
cedricchevalier19 marked this conversation as resolved.
Show resolved Hide resolved
}

static auto view(contig, const std::string &name, int e0) { return Kokkos::View<T *>(name, e0); }
};

template <typename T>
struct ViewBuilder<T, 2> {
static auto view(noncontig, const std::string &name, int e0, int e1) {
Kokkos::View<T ***, Kokkos::LayoutRight> v(name, e0, e1, 2);
return Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, 1);
}

static auto view(contig, const std::string &name, int e0, int e1) { return Kokkos::View<T **>(name, e0, e1); }
};