Skip to content

Commit

Permalink
move the half custom op
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
yhmtsai and MarcelKoch committed Jan 13, 2025
1 parent f50a7d1 commit 68a7045
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 109 deletions.
37 changes: 17 additions & 20 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -8,6 +8,7 @@

#include "core/distributed/vector_kernels.hpp"
#include "core/matrix/dense_kernels.hpp"
#include "core/mpi/mpi_op.hpp"


namespace gko {
Expand Down Expand Up @@ -64,9 +65,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
dim<2> local_size, size_type stride)
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec, local_size, stride},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
local_{exec, local_size, stride}
{
GKO_ASSERT_EQUAL_COLS(global_size, local_size);
}
Expand All @@ -77,9 +76,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
std::unique_ptr<local_vector_type> local_vector)
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
local_{exec}
{
local_vector->move_to(&local_);
}
Expand All @@ -89,11 +86,7 @@ template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
std::unique_ptr<local_vector_type> local_vector)
: EnableLinOp<Vector>{exec, {}},
DistributedBase{comm},
local_{exec},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
: EnableLinOp<Vector>{exec, {}}, DistributedBase{comm}, local_{exec}
{
this->set_size(compute_global_size(exec, comm, local_vector->get_size()));
local_vector->move_to(&local_);
Expand Down Expand Up @@ -470,16 +463,17 @@ void Vector<ValueType>::compute_dot(ptr_param<const LinOp> b,
this->get_local_vector()->compute_dot(as<Vector>(b)->get_local_vector(),
dense_res.get(), tmp);
exec->synchronize();
auto sum_op = gko::experimental::mpi::sum<ValueType>();
if (mpi::requires_host_buffer(exec, comm)) {
host_reduction_buffer_.init(exec->get_master(), dense_res->get_size());
host_reduction_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(),
host_reduction_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), sum_op_.get());
static_cast<int>(this->get_size()[1]), sum_op.get());
dense_res->copy_from(host_reduction_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), sum_op_.get());
static_cast<int>(this->get_size()[1]), sum_op.get());
}
}

Expand All @@ -506,16 +500,17 @@ void Vector<ValueType>::compute_conj_dot(ptr_param<const LinOp> b,
this->get_local_vector()->compute_conj_dot(
as<Vector>(b)->get_local_vector(), dense_res.get(), tmp);
exec->synchronize();
auto sum_op = gko::experimental::mpi::sum<ValueType>();
if (mpi::requires_host_buffer(exec, comm)) {
host_reduction_buffer_.init(exec->get_master(), dense_res->get_size());
host_reduction_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(),
host_reduction_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), sum_op_.get());
static_cast<int>(this->get_size()[1]), sum_op.get());
dense_res->copy_from(host_reduction_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), sum_op_.get());
static_cast<int>(this->get_size()[1]), sum_op.get());
}
}

Expand Down Expand Up @@ -560,17 +555,18 @@ void Vector<ValueType>::compute_norm1(ptr_param<LinOp> result,
auto dense_res = make_temporary_clone(exec, as<NormVector>(result));
this->get_local_vector()->compute_norm1(dense_res.get());
exec->synchronize();
auto norm_sum_op = gko::experimental::mpi::sum<remove_complex<ValueType>>();
if (mpi::requires_host_buffer(exec, comm)) {
host_norm_buffer_.init(exec->get_master(), dense_res->get_size());
host_norm_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(),
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
norm_sum_op.get());
dense_res->copy_from(host_norm_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
norm_sum_op.get());
}
}

Expand All @@ -595,17 +591,18 @@ void Vector<ValueType>::compute_squared_norm2(ptr_param<LinOp> result,
exec->run(vector::make_compute_squared_norm2(this->get_local_vector(),
dense_res.get(), tmp));
exec->synchronize();
auto norm_sum_op = gko::experimental::mpi::sum<remove_complex<ValueType>>();
if (mpi::requires_host_buffer(exec, comm)) {
host_norm_buffer_.init(exec->get_master(), dense_res->get_size());
host_norm_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(),
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
norm_sum_op.get());
dense_res->copy_from(host_norm_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
norm_sum_op.get());
}
}

Expand Down
107 changes: 107 additions & 0 deletions core/mpi/mpi_op.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_MPI_MPI_OP_HPP_
#define GKO_CORE_MPI_MPI_OP_HPP_

#include <complex>
#include <type_traits>


#if GINKGO_BUILD_MPI


#include <mpi.h>


namespace gko {
namespace experimental {
namespace mpi {
namespace detail {


template <typename ValueType>
inline void sum(void* input, void* output, int* len, MPI_Datatype* datatype)
{
ValueType* input_ptr = static_cast<ValueType*>(input);
ValueType* output_ptr = static_cast<ValueType*>(output);
for (int i = 0; i < *len; i++) {
output_ptr[i] += input_ptr[i];
}
}

template <typename ValueType>
inline void max(void* input, void* output, int* len, MPI_Datatype* datatype)
{
ValueType* input_ptr = static_cast<ValueType*>(input);
ValueType* output_ptr = static_cast<ValueType*>(output);
for (int i = 0; i < *len; i++) {
if (input_ptr[i] > output_ptr[i]) {
output_ptr[i] = input_ptr[i];
}
}
}

template <typename ValueType>
struct is_mpi_native {
constexpr static bool value =
std::is_arithmetic_v<ValueType> ||
std::is_same_v<ValueType, std::complex<float>> ||
std::is_same_v<ValueType, std::complex<double>>;
};


} // namespace detail


using op_manager = std::shared_ptr<std::pointer_traits<MPI_Op>::element_type>;


template <typename ValueType,
std::enable_if_t<detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager sum()
{
return op_manager([]() { return MPI_SUM; }(), [](MPI_Op op) {});
}

template <typename ValueType,
std::enable_if_t<!detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager sum()
{
return op_manager(
[]() {
MPI_Op operation;
MPI_Op_create(&detail::sum<ValueType>, 1, &operation);
return operation;
}(),
[](MPI_Op op) { MPI_Op_free(&op); });
}


template <typename ValueType,
std::enable_if_t<detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager max()
{
return op_manager([]() { return MPI_MAX; }(), [](MPI_Op op) {});
}

template <typename ValueType,
std::enable_if_t<!detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager max()
{
return op_manager(
[]() {
MPI_Op operation;
MPI_Op_create(&detail::max<ValueType>, 1, &operation);
return operation;
}(),
[](MPI_Op op) { MPI_Op_free(&op); });
}

} // namespace mpi
} // namespace experimental
} // namespace gko

#endif
#endif // GKO_CORE_MPI_MPI_OP_HPP_
1 change: 1 addition & 0 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/mpi/mpi_op.hpp"
#include "core/solver/common_gmres_kernels.hpp"
#include "core/solver/gmres_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand Down
87 changes: 2 additions & 85 deletions include/ginkgo/core/base/mpi.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -91,97 +91,14 @@ GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
#if GINKGO_ENABLE_HALF
// OpenMPI 5.0 have support from MPIX_C_FLOAT16 and MPICHv3.4a1 MPIX_C_FLOAT16
// TODO: it only works on the transferring
// TODO: adapt it when MPI is configured to support half natively
GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
#endif // GKO_ENABLE_HALF
GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);


namespace detail {


template <typename ValueType>
inline void sum(void* input, void* output, int* len, MPI_Datatype* datatype)
{
ValueType* input_ptr = static_cast<ValueType*>(input);
ValueType* output_ptr = static_cast<ValueType*>(output);
for (int i = 0; i < *len; i++) {
output_ptr[i] += input_ptr[i];
}
}

template <typename ValueType>
inline void max(void* input, void* output, int* len, MPI_Datatype* datatype)
{
ValueType* input_ptr = static_cast<ValueType*>(input);
ValueType* output_ptr = static_cast<ValueType*>(output);
for (int i = 0; i < *len; i++) {
if (input_ptr[i] > output_ptr[i]) {
output_ptr[i] = input_ptr[i];
}
}
}

template <typename ValueType>
struct is_mpi_native {
constexpr static bool value =
std::is_arithmetic_v<ValueType> ||
std::is_same_v<ValueType, std::complex<float>> ||
std::is_same_v<ValueType, std::complex<double>>;
};


} // namespace detail


// using op_manager = std::unique_ptr<std::pointer_traits<MPI_Op>::element_type,
// std::function<void(MPI_Op)>>;
using op_manager = std::shared_ptr<std::pointer_traits<MPI_Op>::element_type>;

template <typename ValueType,
std::enable_if_t<detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager sum()
{
return op_manager([]() { return MPI_SUM; }(), [](MPI_Op op) {});
}

template <typename ValueType,
std::enable_if_t<!detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager sum()
{
return op_manager(
[]() {
MPI_Op operation;
MPI_Op_create(&detail::sum<ValueType>, 1, &operation);
return operation;
}(),
[](MPI_Op op) { MPI_Op_free(&op); });
}


template <typename ValueType,
std::enable_if_t<detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager max()
{
return op_manager([]() { return MPI_MAX; }(), [](MPI_Op op) {});
}

template <typename ValueType,
std::enable_if_t<!detail::is_mpi_native<ValueType>::value>* = nullptr>
inline op_manager max()
{
return op_manager(
[]() {
MPI_Op operation;
MPI_Op_create(&detail::max<ValueType>, 1, &operation);
return operation;
}(),
[](MPI_Op op) { MPI_Op_free(&op); });
}


/**
* A move-only wrapper for a contiguous MPI_Datatype.
*
Expand Down
6 changes: 2 additions & 4 deletions include/ginkgo/core/distributed/vector.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -662,8 +662,6 @@ class Vector
local_vector_type local_;
::gko::detail::DenseCache<ValueType> host_reduction_buffer_;
::gko::detail::DenseCache<remove_complex<ValueType>> host_norm_buffer_;
mpi::op_manager sum_op_;
mpi::op_manager norm_sum_op_;
};


Expand Down Expand Up @@ -700,7 +698,7 @@ struct conversion_target_helper<experimental::distributed::Vector<ValueType>> {
}

// Allow to create_empty of the same type
// For distributed case, next<next<V>> will be V in the candicated list.
// For distributed case, next<next<V>> will be V in the candidate list.
// TODO: decide to whether to add this or add condition to the list
static std::unique_ptr<target_type> create_empty(const target_type* source)
{
Expand Down

0 comments on commit 68a7045

Please sign in to comment.