diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 341983e72fd..623520fddda 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -23,6 +23,7 @@ target_sources(ginkgo config/config_helper.cpp config/property_tree.cpp config/registry.cpp + config/solver_config.cpp config/stop_config.cpp config/type_descriptor.cpp distributed/index_map.cpp diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index 798d3623856..30839297932 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -30,7 +30,21 @@ namespace config { * LinOpFactoryType enum is to avoid forward declaration, linopfactory header, * two template versions of parse */ -enum class LinOpFactoryType : int { Cg = 0 }; +enum class LinOpFactoryType : int { + Cg = 0, + Bicg, + Bicgstab, + Fcg, + Cgs, + Ir, + Idr, + Gcr, + Gmres, + CbGmres, + Direct, + LowerTrs, + UpperTrs +}; /** @@ -107,13 +121,29 @@ inline std::vector> parse_or_get_factory_vector( } +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for bool type + */ +template +inline std::enable_if_t::value, bool> get_value( + const pnode& config) +{ + auto val = config.get_boolean(); + return val; +} + + /** * get_value gets the corresponding type value from config. * * This is specialization for integral type */ template -inline std::enable_if_t::value, IndexType> +inline std::enable_if_t::value && + !std::is_same::value, + IndexType> get_value(const pnode& config) { auto val = config.get_integer(); @@ -173,6 +203,29 @@ get_value(const pnode& config) } +/** + * get_value gets the corresponding type value from config. + * + * This is specialization for initial_guess_mode + */ +template +inline std::enable_if_t< + std::is_same::value, + solver::initial_guess_mode> +get_value(const pnode& config) +{ + auto val = config.get_string(); + if (val == "zero") { + return solver::initial_guess_mode::zero; + } else if (val == "rhs") { + return solver::initial_guess_mode::rhs; + } else if (val == "provided") { + return solver::initial_guess_mode::provided; + } + GKO_INVALID_STATE("Wrong value for initial_guess_mode"); +} + + } // namespace config } // namespace gko diff --git a/core/config/dispatch.hpp b/core/config/dispatch.hpp index c765150f72a..5bf5dc3273e 100644 --- a/core/config/dispatch.hpp +++ b/core/config/dispatch.hpp @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -106,6 +107,7 @@ deferred_factory_parameter dispatch( using value_type_list = syn::type_list, std::complex>; +using index_type_list = syn::type_list; } // namespace config } // namespace gko diff --git a/core/config/parse_macro.hpp b/core/config/parse_macro.hpp new file mode 100644 index 00000000000..cbc9438fbb7 --- /dev/null +++ b/core/config/parse_macro.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_PARSE_MACRO_HPP_ +#define GKO_CORE_CONFIG_PARSE_MACRO_HPP_ + + +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" +#include "core/config/type_descriptor_helper.hpp" + + +// for value_type only +#define GKO_PARSE_VALUE_TYPE(_type, _configurator) \ + template <> \ + deferred_factory_parameter \ + parse( \ + const gko::config::pnode& config, \ + const gko::config::registry& context, \ + const gko::config::type_descriptor& td) \ + { \ + auto updated = gko::config::update_type(config, td); \ + return gko::config::dispatch( \ + config, context, updated, \ + gko::config::make_type_selector(updated.get_value_typestr(), \ + gko::config::value_type_list())); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + +// for value_type and index_type +#define GKO_PARSE_VALUE_AND_INDEX_TYPE(_type, _configurator) \ + template <> \ + deferred_factory_parameter \ + parse( \ + const gko::config::pnode& config, \ + const gko::config::registry& context, \ + const gko::config::type_descriptor& td) \ + { \ + auto updated = gko::config::update_type(config, td); \ + return gko::config::dispatch( \ + config, context, updated, \ + gko::config::make_type_selector(updated.get_value_typestr(), \ + gko::config::value_type_list()), \ + gko::config::make_type_selector(updated.get_index_typestr(), \ + gko::config::index_type_list())); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + +#endif // GKO_CORE_CONFIG_PARSE_MACRO_HPP_ diff --git a/core/config/registry.cpp b/core/config/registry.cpp index 1113adb93f4..d38c497973b 100644 --- a/core/config/registry.cpp +++ b/core/config/registry.cpp @@ -18,7 +18,19 @@ namespace config { configuration_map generate_config_map() { - return {{"solver::Cg", parse}}; + return {{"solver::Cg", parse}, + {"solver::Bicg", parse}, + {"solver::Bicgstab", parse}, + {"solver::Fcg", parse}, + {"solver::Cgs", parse}, + {"solver::Ir", parse}, + {"solver::Idr", parse}, + {"solver::Gcr", parse}, + {"solver::Gmres", parse}, + {"solver::CbGmres", parse}, + {"solver::Direct", parse}, + {"solver::LowerTrs", parse}, + {"solver::UpperTrs", parse}}; } diff --git a/core/config/solver_config.cpp b/core/config/solver_config.cpp new file mode 100644 index 00000000000..9d09643a3ab --- /dev/null +++ b/core/config/solver_config.cpp @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" +#include "core/config/parse_macro.hpp" +#include "core/config/solver_config.hpp" + + +namespace gko { +namespace config { + + +GKO_PARSE_VALUE_TYPE(Cg, gko::solver::Cg); +GKO_PARSE_VALUE_TYPE(Bicg, gko::solver::Bicg); +GKO_PARSE_VALUE_TYPE(Bicgstab, gko::solver::Bicgstab); +GKO_PARSE_VALUE_TYPE(Cgs, gko::solver::Cgs); +GKO_PARSE_VALUE_TYPE(Fcg, gko::solver::Fcg); +GKO_PARSE_VALUE_TYPE(Ir, gko::solver::Ir); +GKO_PARSE_VALUE_TYPE(Idr, gko::solver::Idr); +GKO_PARSE_VALUE_TYPE(Gcr, gko::solver::Gcr); +GKO_PARSE_VALUE_TYPE(Gmres, gko::solver::Gmres); +GKO_PARSE_VALUE_TYPE(CbGmres, gko::solver::CbGmres); +GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct); +GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs); +GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs); + + +} // namespace config +} // namespace gko diff --git a/core/config/solver_config.hpp b/core/config/solver_config.hpp new file mode 100644 index 00000000000..3c820541f2c --- /dev/null +++ b/core/config/solver_config.hpp @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_ +#define GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_ + + +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" + +namespace gko { +namespace config { + + +template +inline void common_solver_parse(SolverParam& params, const pnode& config, + const registry& context, + type_descriptor td_for_child) +{ + if (auto& obj = config.get("generated_preconditioner")) { + params.with_generated_preconditioner( + gko::config::get_stored_obj(obj, context)); + } + if (auto& obj = config.get("criteria")) { + params.with_criteria( + gko::config::parse_or_get_factory_vector< + const stop::CriterionFactory>(obj, context, td_for_child)); + } + if (auto& obj = config.get("preconditioner")) { + params.with_preconditioner( + gko::config::parse_or_get_factory( + obj, context, td_for_child)); + } +} + + +} // namespace config +} // namespace gko + +#endif // GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_ diff --git a/core/config/stop_config.cpp b/core/config/stop_config.cpp index 63148cbfcd9..5b585924ee5 100644 --- a/core/config/stop_config.cpp +++ b/core/config/stop_config.cpp @@ -27,22 +27,22 @@ namespace config { deferred_factory_parameter configure_time( const pnode& config, const registry& context, const type_descriptor& td) { - auto factory = stop::Time::build(); + auto params = stop::Time::build(); if (auto& obj = config.get("time_limit")) { - factory.with_time_limit(gko::config::get_value(obj)); + params.with_time_limit(gko::config::get_value(obj)); } - return factory; + return params; } deferred_factory_parameter configure_iter( const pnode& config, const registry& context, const type_descriptor& td) { - auto factory = stop::Iteration::build(); + auto params = stop::Iteration::build(); if (auto& obj = config.get("max_iters")) { - factory.with_max_iters(gko::config::get_value(obj)); + params.with_max_iters(gko::config::get_value(obj)); } - return factory; + return params; } diff --git a/core/config/trisolver_config.hpp b/core/config/trisolver_config.hpp new file mode 100644 index 00000000000..b7889fdcef2 --- /dev/null +++ b/core/config/trisolver_config.hpp @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_ +#define GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_ + + +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/dispatch.hpp" + +namespace gko { +namespace config { + + +template +inline void common_trisolver_parse(SolverParam& params, const pnode& config, + const registry& context, + type_descriptor td_for_child) +{ + if (auto& obj = config.get("num_rhs")) { + params.with_num_rhs(gko::config::get_value(obj)); + } + if (auto& obj = config.get("unit_diagonal")) { + params.with_unit_diagonal(gko::config::get_value(obj)); + } + if (auto& obj = config.get("algorithm")) { + using gko::solver::trisolve_algorithm; + auto str = obj.get_string(); + if (str == "sparselib") { + params.with_algorithm(trisolve_algorithm::sparselib); + } else if (str == "syncfree") { + params.with_algorithm(trisolve_algorithm::syncfree); + } else { + GKO_INVALID_STATE("Wrong value for algorithm"); + } + } +} + + +} // namespace config +} // namespace gko + +#endif // GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_ diff --git a/core/solver/bicg.cpp b/core/solver/bicg.cpp index 876509c893f..b5831c33ada 100644 --- a/core/solver/bicg.cpp +++ b/core/solver/bicg.cpp @@ -13,6 +13,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/solver/bicg_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -32,6 +33,17 @@ GKO_REGISTER_OPERATION(step_2, bicg::step_2); } // namespace bicg +template +typename Bicg::parameters_type Bicg::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Bicg::build(); + common_solver_parse(params, config, context, td_for_child); + return params; +} + + template std::unique_ptr Bicg::transpose() const { diff --git a/core/solver/bicgstab.cpp b/core/solver/bicgstab.cpp index 2d3f55f28d6..c6ae33918a1 100644 --- a/core/solver/bicgstab.cpp +++ b/core/solver/bicgstab.cpp @@ -14,6 +14,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/bicgstab_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -36,6 +37,17 @@ GKO_REGISTER_OPERATION(finalize, bicgstab::finalize); } // namespace bicgstab +template +typename Bicgstab::parameters_type Bicgstab::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Bicgstab::build(); + common_solver_parse(params, config, context, td_for_child); + return params; +} + + template std::unique_ptr Bicgstab::transpose() const { diff --git a/core/solver/cb_gmres.cpp b/core/solver/cb_gmres.cpp index 0e15206197c..84717b15f61 100644 --- a/core/solver/cb_gmres.cpp +++ b/core/solver/cb_gmres.cpp @@ -20,6 +20,7 @@ #include "core/base/extended_float.hpp" +#include "core/config/solver_config.hpp" #include "core/solver/cb_gmres_accessor.hpp" #include "core/solver/cb_gmres_kernels.hpp" @@ -158,6 +159,40 @@ struct helper> { }; +template +typename CbGmres::parameters_type CbGmres::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::CbGmres::build(); + common_solver_parse(params, config, context, td_for_child); + if (auto& obj = config.get("krylov_dim")) { + params.with_krylov_dim(gko::config::get_value(obj)); + } + if (auto& obj = config.get("storage_precision")) { + auto get_storage_precision = [](std::string str) { + using gko::solver::cb_gmres::storage_precision; + if (str == "keep") { + return storage_precision::keep; + } else if (str == "reduce1") { + return storage_precision::reduce1; + } else if (str == "reduce2") { + return storage_precision::reduce2; + } else if (str == "integer") { + return storage_precision::integer; + } else if (str == "ireduce1") { + return storage_precision::ireduce1; + } else if (str == "ireduce2") { + return storage_precision::ireduce2; + } + GKO_INVALID_STATE("Wrong value for storage_precision"); + }; + params.with_storage_precision(get_storage_precision(obj.get_string())); + } + return params; +} + + template void CbGmres::apply_impl(const LinOp* b, LinOp* x) const { diff --git a/core/solver/cg.cpp b/core/solver/cg.cpp index 71e5fcfbb3b..f83faf7e20f 100644 --- a/core/solver/cg.cpp +++ b/core/solver/cg.cpp @@ -12,36 +12,15 @@ #include #include #include -#include -#include -#include "core/config/config_helper.hpp" -#include "core/config/dispatch.hpp" -#include "core/config/type_descriptor_helper.hpp" +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/cg_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" namespace gko { -namespace config { - - -template <> -deferred_factory_parameter parse( - const pnode& config, const registry& context, const type_descriptor& td) -{ - auto updated = update_type(config, td); - return dispatch( - config, context, updated, - make_type_selector(updated.get_value_typestr(), value_type_list())); -} - - -} // namespace config - - namespace solver { namespace cg { namespace { @@ -62,21 +41,7 @@ typename Cg::parameters_type Cg::parse( const config::type_descriptor& td_for_child) { auto params = solver::Cg::build(); - // The following will be moved to the common solver function in another pr - if (auto& obj = config.get("generated_preconditioner")) { - params.with_generated_preconditioner( - gko::config::get_stored_obj(obj, context)); - } - if (auto& obj = config.get("criteria")) { - params.with_criteria( - gko::config::parse_or_get_factory_vector< - const stop::CriterionFactory>(obj, context, td_for_child)); - } - if (auto& obj = config.get("preconditioner")) { - params.with_preconditioner( - gko::config::parse_or_get_factory( - obj, context, td_for_child)); - } + common_solver_parse(params, config, context, td_for_child); return params; } diff --git a/core/solver/cgs.cpp b/core/solver/cgs.cpp index 9a06cca2439..6bb41338f77 100644 --- a/core/solver/cgs.cpp +++ b/core/solver/cgs.cpp @@ -14,6 +14,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/cgs_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -35,6 +36,17 @@ GKO_REGISTER_OPERATION(step_3, cgs::step_3); } // namespace cgs +template +typename Cgs::parameters_type Cgs::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Cgs::build(); + common_solver_parse(params, config, context, td_for_child); + return params; +} + + template std::unique_ptr Cgs::transpose() const { diff --git a/core/solver/direct.cpp b/core/solver/direct.cpp index 7b55dc38bc6..d540aa584f0 100644 --- a/core/solver/direct.cpp +++ b/core/solver/direct.cpp @@ -13,11 +13,33 @@ #include +#include "core/config/config_helper.hpp" + + namespace gko { namespace experimental { namespace solver { +template +typename Direct::parameters_type +Direct::parse(const config::pnode& config, + const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = Direct::build(); + if (auto& obj = config.get("num_rhs")) { + params.with_num_rhs(gko::config::get_value(obj)); + } + if (auto& obj = config.get("factorization")) { + params.with_factorization( + gko::config::parse_or_get_factory( + obj, context, td_for_child)); + } + return params; +} + + template std::unique_ptr Direct::transpose() const GKO_NOT_IMPLEMENTED; diff --git a/core/solver/fcg.cpp b/core/solver/fcg.cpp index ad4e6069c58..5966664c14d 100644 --- a/core/solver/fcg.cpp +++ b/core/solver/fcg.cpp @@ -13,6 +13,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/fcg_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -33,6 +34,17 @@ GKO_REGISTER_OPERATION(step_2, fcg::step_2); } // namespace fcg +template +typename Fcg::parameters_type Fcg::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Fcg::build(); + common_solver_parse(params, config, context, td_for_child); + return params; +} + + template std::unique_ptr Fcg::transpose() const { diff --git a/core/solver/gcr.cpp b/core/solver/gcr.cpp index 910b6230119..24fb36aa42b 100644 --- a/core/solver/gcr.cpp +++ b/core/solver/gcr.cpp @@ -16,6 +16,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/gcr_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -36,6 +37,20 @@ GKO_REGISTER_OPERATION(step_1, gcr::step_1); } // namespace gcr +template +typename Gcr::parameters_type Gcr::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Gcr::build(); + common_solver_parse(params, config, context, td_for_child); + if (auto& obj = config.get("krylov_dim")) { + params.with_krylov_dim(gko::config::get_value(obj)); + } + return params; +} + + template std::unique_ptr Gcr::transpose() const { diff --git a/core/solver/gmres.cpp b/core/solver/gmres.cpp index f4e80130d90..b261cf754eb 100644 --- a/core/solver/gmres.cpp +++ b/core/solver/gmres.cpp @@ -17,6 +17,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/common_gmres_kernels.hpp" #include "core/solver/gmres_kernels.hpp" @@ -40,6 +41,23 @@ GKO_REGISTER_OPERATION(multi_axpy, gmres::multi_axpy); } // namespace gmres +template +typename Gmres::parameters_type Gmres::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Gmres::build(); + common_solver_parse(params, config, context, td_for_child); + if (auto& obj = config.get("krylov_dim")) { + params.with_krylov_dim(gko::config::get_value(obj)); + } + if (auto& obj = config.get("flexible")) { + params.with_flexible(gko::config::get_value(obj)); + } + return params; +} + + template std::unique_ptr Gmres::transpose() const { diff --git a/core/solver/idr.cpp b/core/solver/idr.cpp index d65b8b5f7c3..9085876a85a 100644 --- a/core/solver/idr.cpp +++ b/core/solver/idr.cpp @@ -13,6 +13,7 @@ #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/idr_kernels.hpp" #include "core/solver/solver_boilerplate.hpp" @@ -35,6 +36,30 @@ GKO_REGISTER_OPERATION(compute_omega, idr::compute_omega); } // namespace idr +template +typename Idr::parameters_type Idr::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Idr::build(); + common_solver_parse(params, config, context, td_for_child); + if (auto& obj = config.get("subspace_dim")) { + params.with_subspace_dim(gko::config::get_value(obj)); + } + if (auto& obj = config.get("kappa")) { + params.with_kappa( + gko::config::get_value>(obj)); + } + if (auto& obj = config.get("deterministic")) { + params.with_deterministic(gko::config::get_value(obj)); + } + if (auto& obj = config.get("complex_subspace")) { + params.with_complex_subspace(gko::config::get_value(obj)); + } + return params; +} + + template std::unique_ptr Idr::transpose() const { diff --git a/core/solver/ir.cpp b/core/solver/ir.cpp index b8258d27dd9..16152dc63e9 100644 --- a/core/solver/ir.cpp +++ b/core/solver/ir.cpp @@ -10,6 +10,7 @@ #include +#include "core/config/config_helper.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/ir_kernels.hpp" #include "core/solver/solver_base.hpp" @@ -29,6 +30,37 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize); } // namespace ir +template +typename Ir::parameters_type Ir::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Ir::build(); + if (auto& obj = config.get("criteria")) { + params.with_criteria( + gko::config::parse_or_get_factory_vector< + const stop::CriterionFactory>(obj, context, td_for_child)); + } + if (auto& obj = config.get("solver")) { + params.with_solver( + gko::config::parse_or_get_factory( + obj, context, td_for_child)); + } + if (auto& obj = config.get("generated_solver")) { + params.with_generated_solver( + gko::config::get_stored_obj(obj, context)); + } + if (auto& obj = config.get("relaxation_factor")) { + params.with_relaxation_factor(gko::config::get_value(obj)); + } + if (auto& obj = config.get("default_initial_guess")) { + params.with_default_initial_guess( + gko::config::get_value(obj)); + } + return params; +} + + template void Ir::set_solver(std::shared_ptr new_solver) { diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp index c4a5454c76a..e36ec98f8fb 100644 --- a/core/solver/lower_trs.cpp +++ b/core/solver/lower_trs.cpp @@ -14,6 +14,8 @@ #include +#include "core/config/config_helper.hpp" +#include "core/config/trisolver_config.hpp" #include "core/solver/lower_trs_kernels.hpp" @@ -33,6 +35,18 @@ GKO_REGISTER_OPERATION(solve, lower_trs::solve); } // namespace lower_trs +template +typename LowerTrs::parameters_type +LowerTrs::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = LowerTrs::build(); + common_trisolver_parse(params, config, context, td_for_child); + return params; +} + + template LowerTrs::LowerTrs(const LowerTrs& other) : EnableLinOp(other.get_executor()) diff --git a/core/solver/upper_trs.cpp b/core/solver/upper_trs.cpp index 5e77cc6061f..5a854bddf1e 100644 --- a/core/solver/upper_trs.cpp +++ b/core/solver/upper_trs.cpp @@ -14,6 +14,8 @@ #include +#include "core/config/config_helper.hpp" +#include "core/config/trisolver_config.hpp" #include "core/solver/upper_trs_kernels.hpp" @@ -33,6 +35,18 @@ GKO_REGISTER_OPERATION(solve, upper_trs::solve); } // namespace upper_trs +template +typename UpperTrs::parameters_type +UpperTrs::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = UpperTrs::build(); + common_trisolver_parse(params, config, context, td_for_child); + return params; +} + + template UpperTrs::UpperTrs(const UpperTrs& other) : EnableLinOp(other.get_executor()) diff --git a/core/test/config/CMakeLists.txt b/core/test/config/CMakeLists.txt index a8783cd4a20..3bda9fa0ff4 100644 --- a/core/test/config/CMakeLists.txt +++ b/core/test/config/CMakeLists.txt @@ -1,3 +1,4 @@ ginkgo_create_test(config) ginkgo_create_test(property_tree) ginkgo_create_test(registry) +ginkgo_create_test(solver) diff --git a/core/test/config/solver.cpp b/core/test/config/solver.cpp new file mode 100644 index 00000000000..b40c4dc1781 --- /dev/null +++ b/core/test/config/solver.cpp @@ -0,0 +1,532 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include "core/config/config_helper.hpp" +#include "core/config/registry_accessor.hpp" +#include "core/test/utils.hpp" + + +using namespace gko::config; + + +using DummySolver = gko::solver::Cg; +using DummyStop = gko::stop::Iteration; + + +template +struct SolverConfigTest { + using changed_solver_type = ChangedSolverType; + using default_solver_type = DefaultSolverType; + using solver_config_test = SolverConfigTest; + + static pnode::map_type setup_base() { return pnode::map_type{}; } + + static void change_template(pnode::map_type& config_map) + { + config_map["value_type"] = pnode{"float32"}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + config_map["generated_preconditioner"] = pnode{"linop"}; + param.with_generated_preconditioner( + detail::registry_accessor::get_data(reg, "linop")); + if (from_reg) { + config_map["criteria"] = pnode{"criterion_factory"}; + param.with_criteria( + detail::registry_accessor::get_data< + gko::stop::CriterionFactory>(reg, "criterion_factory")); + config_map["preconditioner"] = pnode{"linop_factory"}; + param.with_preconditioner( + detail::registry_accessor::get_data( + reg, "linop_factory")); + } else { + config_map["criteria"] = pnode{{{"type", pnode{"Iteration"}}}}; + param.with_criteria(DummyStop::build().on(exec)); + config_map["preconditioner"] = + pnode{{{"type", pnode{"solver::Cg"}}, + {"value_type", pnode{"float64"}}}}; + param.with_preconditioner(DummySolver::build().on(exec)); + } + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + ASSERT_EQ(res_param.generated_preconditioner, + ans_param.generated_preconditioner); + if (from_reg) { + ASSERT_EQ(res_param.criteria, ans_param.criteria); + ASSERT_EQ(res_param.preconditioner, ans_param.preconditioner); + } else { + ASSERT_NE( + std::dynamic_pointer_cast( + res_param.criteria.at(0)), + nullptr); + ASSERT_NE( + std::dynamic_pointer_cast( + res_param.preconditioner), + nullptr); + } + } +}; + + +struct Cg : SolverConfigTest, gko::solver::Cg> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Cg"}}}; + } +}; + + +struct Cgs + : SolverConfigTest, gko::solver::Cgs> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Cgs"}}}; + } +}; + + +struct Fcg + : SolverConfigTest, gko::solver::Fcg> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Fcg"}}}; + } +}; + + +struct Bicg + : SolverConfigTest, gko::solver::Bicg> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Bicg"}}}; + } +}; + + +struct Bicgstab : SolverConfigTest, + gko::solver::Bicgstab> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Bicgstab"}}}; + } +}; + + +struct Ir : SolverConfigTest, gko::solver::Ir> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Ir"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + config_map["generated_solver"] = pnode{"linop"}; + param.with_generated_solver( + detail::registry_accessor::get_data(reg, "linop")); + config_map["relaxation_factor"] = pnode{1.2}; + param.with_relaxation_factor(decltype(param.relaxation_factor){1.2}); + config_map["default_initial_guess"] = pnode{"zero"}; + param.with_default_initial_guess(gko::solver::initial_guess_mode::zero); + if (from_reg) { + config_map["criteria"] = pnode{"criterion_factory"}; + param.with_criteria( + detail::registry_accessor::get_data< + gko::stop::CriterionFactory>(reg, "criterion_factory")); + config_map["solver"] = pnode{"linop_factory"}; + param.with_solver( + detail::registry_accessor::get_data( + reg, "linop_factory")); + } else { + config_map["criteria"] = pnode{{{"type", pnode{"Iteration"}}}}; + param.with_criteria(DummyStop::build().on(exec)); + config_map["solver"] = pnode{{{"type", pnode{"solver::Cg"}}, + {"value_type", pnode{"float64"}}}}; + param.with_solver(DummySolver::build().on(exec)); + } + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + ASSERT_EQ(res_param.generated_solver, ans_param.generated_solver); + ASSERT_EQ(res_param.relaxation_factor, ans_param.relaxation_factor); + ASSERT_EQ(res_param.default_initial_guess, + ans_param.default_initial_guess); + if (from_reg) { + ASSERT_EQ(res_param.criteria, ans_param.criteria); + ASSERT_EQ(res_param.solver, ans_param.solver); + } else { + ASSERT_NE( + std::dynamic_pointer_cast( + res_param.criteria.at(0)), + nullptr); + ASSERT_NE( + std::dynamic_pointer_cast( + res_param.solver), + nullptr); + } + } +}; + + +struct Idr + : SolverConfigTest, gko::solver::Idr> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Idr"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + solver_config_test::template set(config_map, param, reg, + exec); + config_map["subspace_dim"] = pnode{3}; + param.with_subspace_dim(3u); + config_map["kappa"] = pnode{0.9}; + param.with_kappa(decltype(param.kappa){0.9}); + config_map["deterministic"] = pnode{true}; + param.with_deterministic(true); + config_map["complex_subspace"] = pnode{true}; + param.with_complex_subspace(true); + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + solver_config_test::template validate(result, answer); + ASSERT_EQ(res_param.subspace_dim, ans_param.subspace_dim); + ASSERT_EQ(res_param.kappa, ans_param.kappa); + ASSERT_EQ(res_param.deterministic, ans_param.deterministic); + ASSERT_EQ(res_param.complex_subspace, ans_param.complex_subspace); + } +}; + + +struct Gcr + : SolverConfigTest, gko::solver::Gcr> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Gcr"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + solver_config_test::template set(config_map, param, reg, + exec); + config_map["krylov_dim"] = pnode{3}; + param.with_krylov_dim(3u); + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + solver_config_test::template validate(result, answer); + ASSERT_EQ(res_param.krylov_dim, ans_param.krylov_dim); + } +}; + + +struct Gmres + : SolverConfigTest, gko::solver::Gmres> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Gmres"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + solver_config_test::template set(config_map, param, reg, + exec); + config_map["krylov_dim"] = pnode{3}; + param.with_krylov_dim(3u); + config_map["flexible"] = pnode{true}; + param.with_flexible(true); + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + solver_config_test::template validate(result, answer); + ASSERT_EQ(res_param.krylov_dim, ans_param.krylov_dim); + ASSERT_EQ(res_param.flexible, ans_param.flexible); + } +}; + + +struct CbGmres : SolverConfigTest, + gko::solver::CbGmres> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::CbGmres"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + solver_config_test::template set(config_map, param, reg, + exec); + config_map["krylov_dim"] = pnode{3}; + param.with_krylov_dim(3u); + config_map["storage_precision"] = pnode{"reduce2"}; + param.with_storage_precision( + gko::solver::cb_gmres::storage_precision::reduce2); + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + solver_config_test::template validate(result, answer); + ASSERT_EQ(res_param.krylov_dim, ans_param.krylov_dim); + ASSERT_EQ(res_param.storage_precision, ans_param.storage_precision); + } +}; + + +struct Direct + : SolverConfigTest, + gko::experimental::solver::Direct> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Direct"}}}; + } + + static void change_template(pnode::map_type& config_map) + { + config_map["value_type"] = pnode{"float32"}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + config_map["num_rhs"] = pnode{3}; + param.with_num_rhs(3u); + if (from_reg) { + config_map["factorization"] = pnode{"linop_factory"}; + param.with_factorization( + detail::registry_accessor::get_data( + reg, "linop_factory")); + } else { + config_map["factorization"] = + pnode{{{"type", pnode{"solver::Cg"}}, + {"value_type", pnode{"float64"}}}}; + param.with_factorization(DummySolver::build().on(exec)); + } + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + ASSERT_EQ(res_param.num_rhs, ans_param.num_rhs); + if (from_reg) { + ASSERT_EQ(res_param.factorization, ans_param.factorization); + } else { + ASSERT_NE( + std::dynamic_pointer_cast( + res_param.factorization), + nullptr); + } + } +}; + + +template