Skip to content

Commit

Permalink
[dist] remove distributed version of EnableLinOp
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Dec 11, 2024
1 parent f296349 commit bc7c8dc
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 247 deletions.
23 changes: 10 additions & 13 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<const LinOp> local_matrix_template,
ptr_param<const LinOp> non_local_matrix_template)
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
: EnableLinOp<Matrix>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
Expand All @@ -72,8 +71,7 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm, dim<2> size,
std::shared_ptr<LinOp> local_linop)
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
: EnableLinOp<Matrix>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
Expand All @@ -98,8 +96,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
std::vector<comm_index_type> recv_sizes,
std::vector<comm_index_type> recv_offsets,
array<local_index_type> recv_gather_idxs)
: EnableDistributedLinOp<
Matrix<value_type, local_index_type, global_index_type>>{exec},
: EnableLinOp<Matrix>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
send_sizes_(comm.size()),
Expand Down Expand Up @@ -195,9 +192,9 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(
std::vector<comm_index_type> recv_offsets,
array<local_index_type> recv_gather_idxs)
{
return std::unique_ptr<Matrix>{new Matrix{exec, comm, size, local_linop,
non_local_linop, recv_sizes,
recv_offsets, recv_gather_idxs}};
return std::unique_ptr<Matrix>{new Matrix{
exec, comm, size, local_linop, non_local_linop, std::move(recv_sizes),
std::move(recv_offsets), std::move(recv_gather_idxs)}};
}


Expand Down Expand Up @@ -594,8 +591,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::row_scale(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
: EnableLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()}
{
*this = other;
Expand All @@ -605,8 +602,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
Matrix&& other) noexcept
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
: EnableLinOp<Matrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()}
{
*this = std::move(other);
Expand Down
30 changes: 13 additions & 17 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm, dim<2> global_size,
dim<2> local_size, size_type stride)
: EnableDistributedLinOp<Vector<ValueType>>{exec, global_size},
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec, local_size, stride}
{
Expand All @@ -73,7 +73,7 @@ template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm, dim<2> global_size,
std::unique_ptr<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, global_size},
: EnableLinOp<Vector>{exec, global_size},
DistributedBase{comm},
local_{exec}
{
Expand All @@ -85,9 +85,7 @@ template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
std::unique_ptr<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, {}},
DistributedBase{comm},
local_{exec}
: EnableLinOp<Vector>{exec, {}}, DistributedBase{comm}, local_{exec}
{
this->set_size(compute_global_size(exec, comm, local_vector->get_size()));
local_vector->move_to(&local_);
Expand Down Expand Up @@ -141,9 +139,9 @@ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
auto non_const_local_vector =
const_cast<local_vector_type*>(local_vector.release());

return std::unique_ptr<const Vector<ValueType>>(new Vector<ValueType>(
std::move(exec), std::move(comm), global_size,
std::unique_ptr<local_vector_type>{non_const_local_vector}));
return std::unique_ptr<const Vector>(
new Vector(std::move(exec), std::move(comm), global_size,
std::unique_ptr<local_vector_type>{non_const_local_vector}));
}


Expand All @@ -154,8 +152,8 @@ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
{
auto global_size =
compute_global_size(exec, comm, local_vector->get_size());
return Vector<ValueType>::create_const(
std::move(exec), std::move(comm), global_size, std::move(local_vector));
return Vector::create_const(std::move(exec), std::move(comm), global_size,
std::move(local_vector));
}


Expand All @@ -173,18 +171,16 @@ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(

template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of(
ptr_param<const Vector<ValueType>> other,
std::shared_ptr<const Executor> exec)
ptr_param<const Vector> other, std::shared_ptr<const Executor> exec)
{
return (*other).create_with_type_of_impl(exec, {}, {}, 0);
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of(
ptr_param<const Vector<ValueType>> other,
std::shared_ptr<const Executor> exec, const dim<2>& global_size,
const dim<2>& local_size, size_type stride)
ptr_param<const Vector> other, std::shared_ptr<const Executor> exec,
const dim<2>& global_size, const dim<2>& local_size, size_type stride)
{
return (*other).create_with_type_of_impl(exec, global_size, local_size,
stride);
Expand Down Expand Up @@ -410,7 +406,7 @@ template <typename ValueType>
void Vector<ValueType>::add_scaled(ptr_param<const LinOp> alpha,
ptr_param<const LinOp> b)
{
auto dense_b = as<Vector<ValueType>>(b);
auto dense_b = as<Vector>(b);
local_.add_scaled(alpha, dense_b->get_local_vector());
}

Expand All @@ -419,7 +415,7 @@ template <typename ValueType>
void Vector<ValueType>::sub_scaled(ptr_param<const LinOp> alpha,
ptr_param<const LinOp> b)
{
auto dense_b = as<Vector<ValueType>>(b);
auto dense_b = as<Vector>(b);
local_.sub_scaled(alpha, dense_b->get_local_vector());
}

Expand Down
30 changes: 25 additions & 5 deletions include/ginkgo/core/base/polymorphic_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@


namespace gko {
namespace experimental {
namespace distributed {


class DistributedBase;


}
} // namespace experimental


/**
Expand Down Expand Up @@ -647,9 +656,6 @@ std::shared_ptr<const R> copy_and_convert_to(
* ConvertibleTo<ConcreteObject> interface). To enable a default
* implementation of this interface see the EnablePolymorphicAssignment
* mixin.
* @note This mixin can't be used with concrete types that derive from
* experimental::distributed::DistributedBase. In that case use
* experimental::EnableDistributedPolymorphicObject instead.
*
* @tparam ConcreteObject the concrete type which is being implemented
* [CRTP parameter]
Expand All @@ -667,7 +673,14 @@ class EnablePolymorphicObject
std::unique_ptr<PolymorphicObject> create_default_impl(
std::shared_ptr<const Executor> exec) const override
{
return std::unique_ptr<ConcreteObject>{new ConcreteObject(exec)};
if constexpr (std::is_base_of_v<
experimental::distributed::DistributedBase,
ConcreteObject>) {
return std::unique_ptr<ConcreteObject>{
new ConcreteObject(exec, self()->get_communicator())};
} else {
return std::unique_ptr<ConcreteObject>{new ConcreteObject(exec)};
}
}

PolymorphicObject* copy_from_impl(const PolymorphicObject* other) override
Expand Down Expand Up @@ -698,7 +711,14 @@ class EnablePolymorphicObject

PolymorphicObject* clear_impl() override
{
*self() = ConcreteObject{this->get_executor()};
if constexpr (std::is_base_of_v<
experimental::distributed::DistributedBase,
ConcreteObject>) {
*self() = ConcreteObject{this->get_executor(),
self()->get_communicator()};
} else {
*self() = ConcreteObject{this->get_executor()};
}
return this;
}

Expand Down
88 changes: 0 additions & 88 deletions include/ginkgo/core/distributed/lin_op.hpp

This file was deleted.

18 changes: 9 additions & 9 deletions include/ginkgo/core/distributed/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@


#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/mpi.hpp>
#include <ginkgo/core/base/std_extensions.hpp>
#include <ginkgo/core/distributed/base.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>


namespace gko {
Expand Down Expand Up @@ -257,12 +257,12 @@ class Vector;
*/
template <typename ValueType = default_precision,
typename LocalIndexType = int32, typename GlobalIndexType = int64>
class Matrix : public EnableDistributedLinOp<
Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
public ConvertibleTo<Matrix<next_precision_base<ValueType>,
LocalIndexType, GlobalIndexType>>,
public DistributedBase {
friend class EnableDistributedPolymorphicObject<Matrix, LinOp>;
class Matrix
: public EnableLinOp<Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
public ConvertibleTo<Matrix<next_precision_base<ValueType>,
LocalIndexType, GlobalIndexType>>,
public DistributedBase {
friend class EnablePolymorphicObject<Matrix, LinOp>;
friend class Matrix<next_precision_base<ValueType>, LocalIndexType,
GlobalIndexType>;
friend class multigrid::Pgm<ValueType, LocalIndexType>;
Expand All @@ -276,8 +276,8 @@ class Matrix : public EnableDistributedLinOp<
gko::experimental::distributed::Vector<ValueType>;
using local_vector_type = typename global_vector_type::local_vector_type;

using EnableDistributedLinOp<Matrix>::convert_to;
using EnableDistributedLinOp<Matrix>::move_to;
using EnableLinOp<Matrix>::convert_to;
using EnableLinOp<Matrix>::move_to;
using ConvertibleTo<Matrix<next_precision_base<ValueType>, LocalIndexType,
GlobalIndexType>>::convert_to;
using ConvertibleTo<Matrix<next_precision_base<ValueType>, LocalIndexType,
Expand Down
Loading

0 comments on commit bc7c8dc

Please sign in to comment.