Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable Half in mpi #1759

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/cuda_hip/distributed/assembly_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void count_non_owning_entries(
num_parts, local_part, row_part_ptrs.get_data(), send_count.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_COUNT_NON_OWNING_ENTRIES);


Expand Down
16 changes: 8 additions & 8 deletions common/cuda_hip/distributed/matrix_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ void separate_local_nonlocal(
col_range_starting_indices[range_id];
};

using input_type = input_type<ValueType, GlobalIndexType>;
using input_type = input_type<device_type<ValueType>, GlobalIndexType>;
auto input_it = thrust::make_zip_iterator(thrust::make_tuple(
input.get_const_row_idxs(), input.get_const_col_idxs(),
input.get_const_values(), row_range_ids.get_const_data(),
col_range_ids.get_const_data()));
as_device_type(input.get_const_values()),
row_range_ids.get_const_data(), col_range_ids.get_const_data()));

// copy and transform local entries into arrays
local_row_idxs.resize_and_reset(num_local_elements);
Expand All @@ -157,9 +157,9 @@ void separate_local_nonlocal(
thrust::copy_if(
policy, local_it, local_it + input.get_num_stored_elements(),
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(local_row_idxs.get_data(),
local_col_idxs.get_data(),
local_values.get_data())),
thrust::make_zip_iterator(thrust::make_tuple(
local_row_idxs.get_data(), local_col_idxs.get_data(),
as_device_type(local_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -185,7 +185,7 @@ void separate_local_nonlocal(
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(
non_local_row_idxs.get_data(), non_local_col_idxs.get_data(),
non_local_values.get_data())),
as_device_type(non_local_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -194,7 +194,7 @@ void separate_local_nonlocal(
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);


Expand Down
2 changes: 1 addition & 1 deletion common/cuda_hip/distributed/vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void build_local(
range_id.get_data(), local_mtx->get_values(), is_local_row);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_DISTRIBUTED_VECTOR_BUILD_LOCAL);


Expand Down
2 changes: 1 addition & 1 deletion common/unified/distributed/assembly_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void fill_send_buffers(
send_values.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILL_SEND_BUFFERS);


Expand Down
2 changes: 1 addition & 1 deletion core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
typename GlobalIndexType> \
_macro(ValueType, LocalIndexType, GlobalIndexType) \
GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(_macro)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(_macro)

#define GKO_STUB_TEMPLATE_TYPE_BASE(_macro) \
template <typename IndexType> \
Expand Down
2 changes: 1 addition & 1 deletion core/distributed/assembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ device_matrix_data<ValueType, GlobalIndexType> assemble_rows_from_neighbors(
mpi::communicator comm, \
const device_matrix_data<_value_type, _global_type>& input, \
ptr_param<const Partition<_local_type, _global_type>> partition)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_ASSEMBLE_ROWS_FROM_NEIGHBORS);


Expand Down
14 changes: 5 additions & 9 deletions core/distributed/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,11 @@ void vector_dispatch(T* linop, F&& f, Args&&... args)
{
#if GINKGO_BUILD_MPI
if (is_distributed(linop)) {
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(linop);
} else {
using type = std::conditional_t<
std::is_const<T>::value,
const experimental::distributed::Vector<ValueType>,
experimental::distributed::Vector<ValueType>>;
f(dynamic_cast<type*>(linop), std::forward<Args>(args)...);
}
using type = std::conditional_t<
std::is_const<T>::value,
const experimental::distributed::Vector<ValueType>,
experimental::distributed::Vector<ValueType>>;
f(dynamic_cast<type*>(linop), std::forward<Args>(args)...);
} else
#endif
{
Expand Down
50 changes: 45 additions & 5 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
Matrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result) const
Matrix<next_precision<value_type>, local_index_type, global_index_type>*
result) const
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -219,8 +219,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
Matrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result)
Matrix<next_precision<value_type>, local_index_type, global_index_type>*
result)
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -237,6 +237,46 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
}


#if GINKGO_ENABLE_HALF
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
// Converts this matrix to the precision two steps along the precision cycle
// (next_precision applied twice). Only compiled when half precision is
// enabled, since only then does the cycle have more than two members.
// Copies (does not move) all local data and the cached communication
// metadata into `result`; `this` is left unchanged.
// NOTE(review): assumes both communicators are congruent beyond the size
// check below — TODO confirm against the class contract.
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
    Matrix<next_precision<next_precision<value_type>>, local_index_type,
           global_index_type>* result) const
{
    // Both matrices must be distributed over the same number of ranks.
    GKO_ASSERT(this->get_communicator().size() ==
               result->get_communicator().size());
    // copy_from performs the value-type conversion of the stored matrices.
    result->local_mtx_->copy_from(this->local_mtx_.get());
    result->non_local_mtx_->copy_from(this->non_local_mtx_.get());
    // The communication pattern is precision-independent, so the cached
    // gather indices, offsets, and sizes can be copied verbatim.
    result->gather_idxs_ = this->gather_idxs_;
    result->send_offsets_ = this->send_offsets_;
    result->recv_offsets_ = this->recv_offsets_;
    result->recv_sizes_ = this->recv_sizes_;
    result->send_sizes_ = this->send_sizes_;
    result->non_local_to_global_ = this->non_local_to_global_;
    result->set_size(this->get_size());
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
// Moves this matrix into the precision two steps along the precision cycle
// (next_precision applied twice). Only compiled when half precision is
// enabled. After the call `this` is emptied: its stored matrices and
// communication metadata are moved out and its size is reset to {}.
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
    Matrix<next_precision<next_precision<value_type>>, local_index_type,
           global_index_type>* result)
{
    // Both matrices must be distributed over the same number of ranks.
    GKO_ASSERT(this->get_communicator().size() ==
               result->get_communicator().size());
    // move_from converts the value type while taking ownership of the
    // underlying storage where possible.
    result->local_mtx_->move_from(this->local_mtx_.get());
    result->non_local_mtx_->move_from(this->non_local_mtx_.get());
    // Communication metadata is precision-independent; move it wholesale.
    result->gather_idxs_ = std::move(this->gather_idxs_);
    result->send_offsets_ = std::move(this->send_offsets_);
    result->recv_offsets_ = std::move(this->recv_offsets_);
    result->recv_sizes_ = std::move(this->recv_sizes_);
    result->send_sizes_ = std::move(this->send_sizes_);
    result->non_local_to_global_ = std::move(this->non_local_to_global_);
    // Read this->get_size() before clearing it below — order matters.
    result->set_size(this->get_size());
    this->set_size({});
}
#endif

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const device_matrix_data<value_type, global_index_type>& data,
Expand Down Expand Up @@ -661,7 +701,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
#define GKO_DECLARE_DISTRIBUTED_MATRIX(ValueType, LocalIndexType, \
GlobalIndexType) \
class Matrix<ValueType, LocalIndexType, GlobalIndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_DISTRIBUTED_MATRIX);


Expand Down
3 changes: 1 addition & 2 deletions core/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(

#define GKO_DECLARE_SCHWARZ(ValueType, LocalIndexType, GlobalIndexType) \
class Schwarz<ValueType, LocalIndexType, GlobalIndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_DECLARE_SCHWARZ);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_SCHWARZ);


} // namespace preconditioner
Expand Down
64 changes: 48 additions & 16 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
dim<2> local_size, size_type stride)
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec, local_size, stride}
local_{exec, local_size, stride},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
{
GKO_ASSERT_EQUAL_COLS(global_size, local_size);
}
Expand All @@ -75,7 +77,9 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
std::unique_ptr<local_vector_type> local_vector)
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec}
local_{exec},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
{
local_vector->move_to(&local_);
}
Expand All @@ -85,7 +89,11 @@ template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
std::unique_ptr<local_vector_type> local_vector)
: EnableLinOp<Vector>{exec, {}}, DistributedBase{comm}, local_{exec}
: EnableLinOp<Vector>{exec, {}},
DistributedBase{comm},
local_{exec},
sum_op_(mpi::sum<ValueType>()),
norm_sum_op_(mpi::sum<remove_complex<ValueType>>())
{
this->set_size(compute_global_size(exec, comm, local_vector->get_size()));
local_vector->move_to(&local_);
Expand Down Expand Up @@ -279,7 +287,7 @@ void Vector<ValueType>::fill(const ValueType value)

template <typename ValueType>
void Vector<ValueType>::convert_to(
Vector<next_precision_base<ValueType>>* result) const
Vector<next_precision<ValueType>>* result) const
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -289,12 +297,32 @@ void Vector<ValueType>::convert_to(


template <typename ValueType>
void Vector<ValueType>::move_to(Vector<next_precision_base<ValueType>>* result)
void Vector<ValueType>::move_to(Vector<next_precision<ValueType>>* result)
{
this->convert_to(result);
}


#if GINKGO_ENABLE_HALF
template <typename ValueType>
// Converts this distributed vector to the precision two steps along the
// precision cycle (next_precision applied twice). Only compiled when half
// precision is enabled. Copies the local block with value conversion and
// mirrors the global size; `this` is left unchanged.
void Vector<ValueType>::convert_to(
    Vector<next_precision<next_precision<ValueType>>>* result) const
{
    // Both vectors must be distributed over the same number of ranks.
    GKO_ASSERT(this->get_communicator().size() ==
               result->get_communicator().size());
    result->set_size(this->get_size());
    // convert_to on the local dense block performs the precision change;
    // no inter-rank communication is needed here.
    this->get_local_vector()->convert_to(&result->local_);
}


template <typename ValueType>
// Move-conversion two steps along the precision cycle. Since the value
// types differ, the data cannot be transferred without conversion anyway,
// so this simply delegates to convert_to (the source keeps its data).
void Vector<ValueType>::move_to(
    Vector<next_precision<next_precision<ValueType>>>* result)
{
    this->convert_to(result);
}
#endif

template <typename ValueType>
std::unique_ptr<typename Vector<ValueType>::absolute_type>
Vector<ValueType>::compute_absolute() const
Expand Down Expand Up @@ -447,11 +475,11 @@ void Vector<ValueType>::compute_dot(ptr_param<const LinOp> b,
host_reduction_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(),
host_reduction_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]), sum_op_.get());
dense_res->copy_from(host_reduction_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]), sum_op_.get());
}
}

Expand Down Expand Up @@ -483,11 +511,11 @@ void Vector<ValueType>::compute_conj_dot(ptr_param<const LinOp> b,
host_reduction_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(),
host_reduction_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]), sum_op_.get());
dense_res->copy_from(host_reduction_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]), sum_op_.get());
}
}

Expand Down Expand Up @@ -536,11 +564,13 @@ void Vector<ValueType>::compute_norm1(ptr_param<LinOp> result,
host_norm_buffer_.init(exec->get_master(), dense_res->get_size());
host_norm_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
dense_res->copy_from(host_norm_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
}
}

Expand Down Expand Up @@ -569,11 +599,13 @@ void Vector<ValueType>::compute_squared_norm2(ptr_param<LinOp> result,
host_norm_buffer_.init(exec->get_master(), dense_res->get_size());
host_norm_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(), host_norm_buffer_->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
dense_res->copy_from(host_norm_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(),
static_cast<int>(this->get_size()[1]), MPI_SUM);
static_cast<int>(this->get_size()[1]),
norm_sum_op_.get());
}
}

Expand Down Expand Up @@ -612,10 +644,10 @@ void Vector<ValueType>::compute_mean(ptr_param<LinOp> result,
host_reduction_buffer_->copy_from(dense_res.get());
comm.all_reduce(exec->get_master(),
host_reduction_buffer_->get_values(), num_vecs,
MPI_SUM);
sum_op_.get());
dense_res->copy_from(host_reduction_buffer_.get());
} else {
comm.all_reduce(exec, dense_res->get_values(), num_vecs, MPI_SUM);
comm.all_reduce(exec, dense_res->get_values(), num_vecs, sum_op_.get());
}
}

Expand Down Expand Up @@ -720,7 +752,7 @@ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of_impl(


#define GKO_DECLARE_DISTRIBUTED_VECTOR(ValueType) class Vector<ValueType>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_BASE(GKO_DECLARE_DISTRIBUTED_VECTOR);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DISTRIBUTED_VECTOR);


} // namespace distributed
Expand Down
4 changes: 2 additions & 2 deletions core/distributed/vector_cache.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -48,7 +48,7 @@ void VectorCache<ValueType>::init_from(


#define GKO_DECLARE_VECTOR_CACHE(_type) class VectorCache<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_BASE(GKO_DECLARE_VECTOR_CACHE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_VECTOR_CACHE);


} // namespace detail
Expand Down
Loading
Loading