Skip to content

Commit

Permalink
use deferred factory and update format
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Dec 1, 2023
1 parent cd0033c commit 70c9452
Show file tree
Hide file tree
Showing 27 changed files with 433 additions and 496 deletions.
385 changes: 42 additions & 343 deletions core/config/solver_config.cpp

Large diffs are not rendered by default.

48 changes: 8 additions & 40 deletions core/config/solver_config.hpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,6 @@
/*******************************<GINKGO LICENSE>******************************
Copyright (c) 2017-2023, the Ginkgo authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/
// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
#define GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
Expand All @@ -48,18 +20,14 @@ namespace config {
template <typename SolverFactory>
inline void common_solver_configure(SolverFactory& factory, const pnode& config,
const registry& context,
std::shared_ptr<const Executor> exec,
type_descriptor td_for_child)
{
SET_POINTER(factory, const LinOp, generated_preconditioner, config, context,
exec, td_for_child);
// handle parameter requires exec
// criteria and preconditioner are almost in each solver -> to another
// function.
SET_POINTER_VECTOR(factory, const stop::CriterionFactory, criteria, config,
context, exec, td_for_child);
SET_POINTER(factory, const LinOpFactory, preconditioner, config, context,
exec, td_for_child);
td_for_child);
SET_FACTORY_VECTOR(factory, const stop::CriterionFactory, criteria, config,
context, td_for_child);
SET_FACTORY(factory, const LinOpFactory, preconditioner, config, context,
td_for_child);
}


Expand Down
12 changes: 12 additions & 0 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/base/precision_dispatch.hpp>


#include "core/config/solver_config.hpp"
#include "core/solver/bicg_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"

Expand All @@ -32,6 +33,17 @@ GKO_REGISTER_OPERATION(step_2, bicg::step_2);
} // namespace bicg


template <typename ValueType>
typename Bicg<ValueType>::parameters_type Bicg<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Bicg<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Bicg<ValueType>::transpose() const
{
Expand Down
13 changes: 13 additions & 0 deletions core/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/bicgstab_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -36,6 +37,18 @@ GKO_REGISTER_OPERATION(finalize, bicgstab::finalize);
} // namespace bicgstab


template <typename ValueType>
typename Bicgstab<ValueType>::parameters_type
Bicgstab<ValueType>::build_from_config(const config::pnode& config,
const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Bicgstab<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Bicgstab<ValueType>::transpose() const
{
Expand Down
35 changes: 35 additions & 0 deletions core/solver/cb_gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


#include "core/base/extended_float.hpp"
#include "core/config/solver_config.hpp"
#include "core/solver/cb_gmres_accessor.hpp"
#include "core/solver/cb_gmres_kernels.hpp"

Expand Down Expand Up @@ -158,6 +159,40 @@ struct helper<std::complex<T>> {
};


template <typename ValueType>
typename CbGmres<ValueType>::parameters_type
CbGmres<ValueType>::build_from_config(const config::pnode& config,
const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::CbGmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
if (config.contains("storage_precision")) {
auto get_storage_precision = [](std::string str) {
using gko::solver::cb_gmres::storage_precision;
if (str == "keep") {
return storage_precision::keep;
} else if (str == "reduce1") {
return storage_precision::reduce1;
} else if (str == "reduce2") {
return storage_precision::reduce2;
} else if (str == "integer") {
return storage_precision::integer;
} else if (str == "ireduce1") {
return storage_precision::ireduce1;
} else if (str == "ireduce2") {
return storage_precision::ireduce2;
}
GKO_INVALID_STATE("Wrong value for storage_precision");
};
factory.with_storage_precision(get_storage_precision(
config.at("storage_precision").get_data<std::string>()));
}
return factory;
}


template <typename ValueType>
void CbGmres<ValueType>::apply_impl(const LinOp* b, LinOp* x) const
{
Expand Down
31 changes: 2 additions & 29 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,13 @@
#include <ginkgo/core/base/utils.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/cg_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"


namespace gko {
namespace config {


template <>
deferred_factory_parameter<gko::LinOpFactory>
build_from_config<static_cast<int>(LinOpFactoryType::Cg)>(
const pnode& config, const registry& context,
gko::config::type_descriptor td)
{
auto updated = update_type(config, td);
return dispatch<gko::LinOpFactory, gko::solver::Cg>(
updated.first, config, context, updated, value_type_list());
}


} // namespace config


namespace solver {
namespace cg {
namespace {
Expand All @@ -58,17 +41,7 @@ typename Cg<ValueType>::parameters_type Cg<ValueType>::build_from_config(
config::type_descriptor td_for_child)
{
auto factory = solver::Cg<ValueType>::build();
SET_POINTER(factory, const LinOp, generated_preconditioner, config, context,
td_for_child);
// handle parameter requires exec
// criteria and preconditioner are almost in each solver -> to another
// function.
SET_FACTORY_VECTOR(factory, const stop::CriterionFactory, criteria, config,
context, td_for_child);
SET_FACTORY(factory, const LinOpFactory, preconditioner, config, context,
td_for_child);
// can also handle preconditioner, criterion here if they are in
// context.
common_solver_configure(factory, config, context, td_for_child);
return factory;
}

Expand Down
12 changes: 12 additions & 0 deletions core/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/cgs_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -35,6 +36,17 @@ GKO_REGISTER_OPERATION(step_3, cgs::step_3);
} // namespace cgs


template <typename ValueType>
typename Cgs<ValueType>::parameters_type Cgs<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Cgs<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Cgs<ValueType>::transpose() const
{
Expand Down
17 changes: 17 additions & 0 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,28 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/config.hpp"


namespace gko {
namespace experimental {
namespace solver {


template <typename ValueType, typename IndexType>
typename Direct<ValueType, IndexType>::parameters_type
Direct<ValueType, IndexType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = Direct<ValueType, IndexType>::build();
SET_VALUE(factory, size_type, num_rhs, config);
SET_FACTORY(factory, const LinOpFactory, factorization, config, context,
td_for_child);
return factory;
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Direct<ValueType, IndexType>::transpose() const
GKO_NOT_IMPLEMENTED;
Expand Down
12 changes: 12 additions & 0 deletions core/solver/fcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/base/utils.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/fcg_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -33,6 +34,17 @@ GKO_REGISTER_OPERATION(step_2, fcg::step_2);
} // namespace fcg


template <typename ValueType>
typename Fcg<ValueType>::parameters_type Fcg<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Fcg<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Fcg<ValueType>::transpose() const
{
Expand Down
13 changes: 13 additions & 0 deletions core/solver/gcr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ginkgo/core/matrix/identity.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/gcr_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -36,6 +37,18 @@ GKO_REGISTER_OPERATION(step_1, gcr::step_1);
} // namespace gcr


template <typename ValueType>
typename Gcr<ValueType>::parameters_type Gcr<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Gcr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Gcr<ValueType>::transpose() const
{
Expand Down
14 changes: 14 additions & 0 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ginkgo/core/matrix/identity.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/common_gmres_kernels.hpp"
#include "core/solver/gmres_kernels.hpp"
Expand All @@ -40,6 +41,19 @@ GKO_REGISTER_OPERATION(multi_axpy, gmres::multi_axpy);
} // namespace gmres


template <typename ValueType>
typename Gmres<ValueType>::parameters_type Gmres<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Gmres<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, krylov_dim, config);
SET_VALUE(factory, bool, flexible, config);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Gmres<ValueType>::transpose() const
{
Expand Down
16 changes: 16 additions & 0 deletions core/solver/idr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/solver/solver_base.hpp>


#include "core/config/solver_config.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/idr_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"
Expand All @@ -35,6 +36,21 @@ GKO_REGISTER_OPERATION(compute_omega, idr::compute_omega);
} // namespace idr


template <typename ValueType>
typename Idr<ValueType>::parameters_type Idr<ValueType>::build_from_config(
const config::pnode& config, const config::registry& context,
config::type_descriptor td_for_child)
{
auto factory = solver::Idr<ValueType>::build();
common_solver_configure(factory, config, context, td_for_child);
SET_VALUE(factory, size_type, subspace_dim, config);
SET_VALUE(factory, remove_complex<ValueType>, kappa, config);
SET_VALUE(factory, bool, deterministic, config);
SET_VALUE(factory, bool, complex_subspace, config);
return factory;
}


template <typename ValueType>
std::unique_ptr<LinOp> Idr<ValueType>::transpose() const
{
Expand Down
Loading

0 comments on commit 70c9452

Please sign in to comment.