diff --git a/core/distributed/vector.cpp b/core/distributed/vector.cpp index 732f8d5b3ef..2e17a5058e9 100644 --- a/core/distributed/vector.cpp +++ b/core/distributed/vector.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -8,6 +8,7 @@ #include "core/distributed/vector_kernels.hpp" #include "core/matrix/dense_kernels.hpp" +#include "core/mpi/mpi_op.hpp" namespace gko { @@ -64,9 +65,7 @@ Vector::Vector(std::shared_ptr exec, dim<2> local_size, size_type stride) : EnableLinOp{exec, global_size}, DistributedBase{comm}, - local_{exec, local_size, stride}, - sum_op_(mpi::sum()), - norm_sum_op_(mpi::sum>()) + local_{exec, local_size, stride} { GKO_ASSERT_EQUAL_COLS(global_size, local_size); } @@ -77,9 +76,7 @@ Vector::Vector(std::shared_ptr exec, std::unique_ptr local_vector) : EnableLinOp{exec, global_size}, DistributedBase{comm}, - local_{exec}, - sum_op_(mpi::sum()), - norm_sum_op_(mpi::sum>()) + local_{exec} { local_vector->move_to(&local_); } @@ -89,11 +86,7 @@ template Vector::Vector(std::shared_ptr exec, mpi::communicator comm, std::unique_ptr local_vector) - : EnableLinOp{exec, {}}, - DistributedBase{comm}, - local_{exec}, - sum_op_(mpi::sum()), - norm_sum_op_(mpi::sum>()) + : EnableLinOp{exec, {}}, DistributedBase{comm}, local_{exec} { this->set_size(compute_global_size(exec, comm, local_vector->get_size())); local_vector->move_to(&local_); @@ -470,16 +463,17 @@ void Vector::compute_dot(ptr_param b, this->get_local_vector()->compute_dot(as(b)->get_local_vector(), dense_res.get(), tmp); exec->synchronize(); + auto sum_op = gko::experimental::mpi::sum(); if (mpi::requires_host_buffer(exec, comm)) { host_reduction_buffer_.init(exec->get_master(), dense_res->get_size()); host_reduction_buffer_->copy_from(dense_res.get()); comm.all_reduce(exec->get_master(), host_reduction_buffer_->get_values(), - static_cast(this->get_size()[1]), sum_op_.get()); + static_cast(this->get_size()[1]), sum_op.get()); dense_res->copy_from(host_reduction_buffer_.get()); } else { comm.all_reduce(exec, dense_res->get_values(), - static_cast(this->get_size()[1]), sum_op_.get()); + static_cast(this->get_size()[1]), sum_op.get()); } } @@ -506,16 +500,17 @@ void Vector::compute_conj_dot(ptr_param b, this->get_local_vector()->compute_conj_dot( as(b)->get_local_vector(), dense_res.get(), tmp); exec->synchronize(); + auto sum_op = gko::experimental::mpi::sum(); if (mpi::requires_host_buffer(exec, comm)) { host_reduction_buffer_.init(exec->get_master(), dense_res->get_size()); host_reduction_buffer_->copy_from(dense_res.get()); comm.all_reduce(exec->get_master(), host_reduction_buffer_->get_values(), - static_cast(this->get_size()[1]), sum_op_.get()); + static_cast(this->get_size()[1]), sum_op.get()); dense_res->copy_from(host_reduction_buffer_.get()); } else { comm.all_reduce(exec, dense_res->get_values(), - static_cast(this->get_size()[1]), sum_op_.get()); + static_cast(this->get_size()[1]), sum_op.get()); } } @@ -560,17 +555,18 @@ void Vector::compute_norm1(ptr_param result, auto dense_res = make_temporary_clone(exec, as(result)); this->get_local_vector()->compute_norm1(dense_res.get()); exec->synchronize(); + auto norm_sum_op = gko::experimental::mpi::sum>(); if (mpi::requires_host_buffer(exec, comm)) { host_norm_buffer_.init(exec->get_master(), dense_res->get_size()); host_norm_buffer_->copy_from(dense_res.get()); comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(), static_cast(this->get_size()[1]), - norm_sum_op_.get()); + norm_sum_op.get()); dense_res->copy_from(host_norm_buffer_.get()); } else { comm.all_reduce(exec, dense_res->get_values(), static_cast(this->get_size()[1]), - norm_sum_op_.get()); + norm_sum_op.get()); } } @@ -595,17 +591,18 @@ void Vector::compute_squared_norm2(ptr_param result, exec->run(vector::make_compute_squared_norm2(this->get_local_vector(), dense_res.get(), tmp)); exec->synchronize(); + auto norm_sum_op = gko::experimental::mpi::sum>(); if (mpi::requires_host_buffer(exec, comm)) { host_norm_buffer_.init(exec->get_master(), dense_res->get_size()); host_norm_buffer_->copy_from(dense_res.get()); comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(), static_cast(this->get_size()[1]), - norm_sum_op_.get()); + norm_sum_op.get()); dense_res->copy_from(host_norm_buffer_.get()); } else { comm.all_reduce(exec, dense_res->get_values(), static_cast(this->get_size()[1]), - norm_sum_op_.get()); + norm_sum_op.get()); } } diff --git a/core/mpi/mpi_op.hpp b/core/mpi/mpi_op.hpp new file mode 100644 index 00000000000..af1476542c3 --- /dev/null +++ b/core/mpi/mpi_op.hpp @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2025 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_MPI_MPI_OP_HPP_ +#define GKO_CORE_MPI_MPI_OP_HPP_ + +#include +#include + + +#if GINKGO_BUILD_MPI + + +#include + + +namespace gko { +namespace experimental { +namespace mpi { +namespace detail { + + +template +inline void sum(void* input, void* output, int* len, MPI_Datatype* datatype) +{ + ValueType* input_ptr = static_cast(input); + ValueType* output_ptr = static_cast(output); + for (int i = 0; i < *len; i++) { + output_ptr[i] += input_ptr[i]; + } +} + +template +inline void max(void* input, void* output, int* len, MPI_Datatype* datatype) +{ + ValueType* input_ptr = static_cast(input); + ValueType* output_ptr = static_cast(output); + for (int i = 0; i < *len; i++) { + if (input_ptr[i] > output_ptr[i]) { + output_ptr[i] = input_ptr[i]; + } + } +} + +template +struct is_mpi_native { + constexpr static bool value = + std::is_arithmetic_v || + std::is_same_v> || + std::is_same_v>; +}; + + +} // namespace detail + + +using op_manager = std::shared_ptr::element_type>; + + +template ::value>* = nullptr> +inline op_manager sum() +{ + return op_manager([]() { return MPI_SUM; }(), [](MPI_Op op) {}); +} + +template ::value>* = nullptr> +inline op_manager sum() +{ + return op_manager( + []() { + MPI_Op operation; + MPI_Op_create(&detail::sum, 1, &operation); + return operation; + }(), + [](MPI_Op op) { MPI_Op_free(&op); }); +} + + +template ::value>* = nullptr> +inline op_manager max() +{ + return op_manager([]() { return MPI_MAX; }(), [](MPI_Op op) {}); +} + +template ::value>* = nullptr> +inline op_manager max() +{ + return op_manager( + []() { + MPI_Op operation; + MPI_Op_create(&detail::max, 1, &operation); + return operation; + }(), + [](MPI_Op op) { MPI_Op_free(&op); }); +} + +} // namespace mpi +} // namespace experimental +} // namespace gko + +#endif +#endif // GKO_CORE_MPI_MPI_OP_HPP_ diff --git a/core/solver/gmres.cpp b/core/solver/gmres.cpp index 067d7d7aad2..10c54cd6715 100644 --- a/core/solver/gmres.cpp +++ b/core/solver/gmres.cpp @@ -17,6 +17,7 @@ #include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" +#include "core/mpi/mpi_op.hpp" #include "core/solver/common_gmres_kernels.hpp" #include "core/solver/gmres_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" diff --git a/include/ginkgo/core/base/mpi.hpp b/include/ginkgo/core/base/mpi.hpp index 555ab7099b1..abaad35c816 100644 --- a/include/ginkgo/core/base/mpi.hpp +++ b/include/ginkgo/core/base/mpi.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -91,7 +91,7 @@ GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE); GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE); #if GINKGO_ENABLE_HALF // OpenMPI 5.0 have support from MPIX_C_FLOAT16 and MPICHv3.4a1 MPIX_C_FLOAT16 -// TODO: it only works on the transferring +// TODO: adapt it when MPI is configured to support half natively GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT); GKO_REGISTER_MPI_TYPE(std::complex, MPI_FLOAT); #endif // GKO_ENABLE_HALF @@ -99,89 +99,6 @@ GKO_REGISTER_MPI_TYPE(std::complex, MPI_C_FLOAT_COMPLEX); GKO_REGISTER_MPI_TYPE(std::complex, MPI_C_DOUBLE_COMPLEX); -namespace detail { - - -template -inline void sum(void* input, void* output, int* len, MPI_Datatype* datatype) -{ - ValueType* input_ptr = static_cast(input); - ValueType* output_ptr = static_cast(output); - for (int i = 0; i < *len; i++) { - output_ptr[i] += input_ptr[i]; - } -} - -template -inline void max(void* input, void* output, int* len, MPI_Datatype* datatype) -{ - ValueType* input_ptr = static_cast(input); - ValueType* output_ptr = static_cast(output); - for (int i = 0; i < *len; i++) { - if (input_ptr[i] > output_ptr[i]) { - output_ptr[i] = input_ptr[i]; - } - } -} - -template -struct is_mpi_native { - constexpr static bool value = - std::is_arithmetic_v || - std::is_same_v> || - std::is_same_v>; -}; - - -} // namespace detail - - -// using op_manager = std::unique_ptr::element_type, -// std::function>; -using op_manager = std::shared_ptr::element_type>; - -template ::value>* = nullptr> -inline op_manager sum() -{ - return op_manager([]() { return MPI_SUM; }(), [](MPI_Op op) {}); -} - -template ::value>* = nullptr> -inline op_manager sum() -{ - return op_manager( - []() { - MPI_Op operation; - MPI_Op_create(&detail::sum, 1, &operation); - return operation; - }(), - [](MPI_Op op) { MPI_Op_free(&op); }); -} - - -template ::value>* = nullptr> -inline op_manager max() -{ - return op_manager([]() { return MPI_MAX; }(), [](MPI_Op op) {}); -} - -template ::value>* = nullptr> -inline op_manager max() -{ - return op_manager( - []() { - MPI_Op operation; - MPI_Op_create(&detail::max, 1, &operation); - return operation; - }(), - [](MPI_Op op) { MPI_Op_free(&op); }); -} - - /** * A move-only wrapper for a contiguous MPI_Datatype. * diff --git a/include/ginkgo/core/distributed/vector.hpp b/include/ginkgo/core/distributed/vector.hpp index 181a7de3460..dc278459fd4 100644 --- a/include/ginkgo/core/distributed/vector.hpp +++ b/include/ginkgo/core/distributed/vector.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -662,8 +662,6 @@ class Vector local_vector_type local_; ::gko::detail::DenseCache host_reduction_buffer_; ::gko::detail::DenseCache> host_norm_buffer_; - mpi::op_manager sum_op_; - mpi::op_manager norm_sum_op_; }; @@ -700,7 +698,7 @@ struct conversion_target_helper> { } // Allow to create_empty of the same type - // For distributed case, next> will be V in the candicated list. + // For distributed case, next> will be V in the candidate list. // TODO: decide to whether to add this or add condition to the list static std::unique_ptr create_empty(const target_type* source) {