Skip to content

Commit

Permalink
remove fixed_coarsening, selector, stop::. update doc and rename
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
Co-authored-by: Thomas Grützmacher <[email protected]>
  • Loading branch information
3 people committed May 28, 2024
1 parent 1dcb908 commit 528d27e
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 186 deletions.
3 changes: 1 addition & 2 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ enum class LinOpFactoryType : int {
Isai,
Jacobi,
Multigrid,
Pgm,
FixedCoarsening
Pgm
};


Expand Down
7 changes: 0 additions & 7 deletions core/config/multigrid_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/multigrid/fixed_coarsening.hpp>
#include <ginkgo/core/multigrid/pgm.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/parse_macro.hpp"
#include "core/config/type_descriptor_helper.hpp"


namespace gko {
Expand Down
4 changes: 1 addition & 3 deletions core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ configuration_map generate_config_map()
{"preconditioner::Isai", parse<LinOpFactoryType::Isai>},
{"preconditioner::Jacobi", parse<LinOpFactoryType::Jacobi>},
{"solver::Multigrid", parse<LinOpFactoryType::Multigrid>},
{"multigrid::Pgm", parse<LinOpFactoryType::Pgm>},
{"multigrid::FixedCoarsening",
parse<LinOpFactoryType::FixedCoarsening>}};
{"multigrid::Pgm", parse<LinOpFactoryType::Pgm>}};
}


Expand Down
15 changes: 0 additions & 15 deletions core/multigrid/fixed_coarsening.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,6 @@ GKO_REGISTER_OPERATION(fill_seq_array, components::fill_seq_array);
} // namespace fixed_coarsening


template <typename ValueType, typename IndexType>
typename FixedCoarsening<ValueType, IndexType>::parameters_type
FixedCoarsening<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = FixedCoarsening<ValueType, IndexType>::build();
// TODO: ARRAY
if (auto& obj = config.get("skip_sorting")) {
factory.with_skip_sorting(gko::config::get_value<bool>(obj));
}
return factory;
}


template <typename ValueType, typename IndexType>
void FixedCoarsening<ValueType, IndexType>::generate()
{
Expand Down
13 changes: 7 additions & 6 deletions core/multigrid/pgm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,27 @@ std::shared_ptr<matrix::Csr<ValueType, IndexType>> generate_coarse(
} // namespace


template <typename ValueType, typename IndexType>
typename Pgm<ValueType, IndexType>::parameters_type
Pgm<ValueType, IndexType>::parse(const config::pnode& config,
const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = Pgm<ValueType, IndexType>::build();
auto params = Pgm<ValueType, IndexType>::build();
if (auto& obj = config.get("max_iterations")) {
factory.with_max_iterations(gko::config::get_value<unsigned>(obj));
params.with_max_iterations(gko::config::get_value<unsigned>(obj));
}
if (auto& obj = config.get("max_unassigned_ratio")) {
factory.with_max_unassigned_ratio(gko::config::get_value<double>(obj));
params.with_max_unassigned_ratio(gko::config::get_value<double>(obj));
}
if (auto& obj = config.get("deterministic")) {
factory.with_deterministic(gko::config::get_value<bool>(obj));
params.with_deterministic(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("skip_sorting")) {
factory.with_skip_sorting(gko::config::get_value<bool>(obj));
params.with_skip_sorting(gko::config::get_value<bool>(obj));
}

return factory;
return params;
}


Expand Down
91 changes: 35 additions & 56 deletions core/solver/multigrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

#include "core/base/dispatch_helper.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/config/config.hpp"
#include "core/config/config_helper.hpp"
#include "core/distributed/helpers.hpp"
#include "core/solver/ir_kernels.hpp"
Expand Down Expand Up @@ -594,115 +593,95 @@ void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
} // namespace multigrid


std::function<size_type(const size_type, const LinOp*)> get_selector(
std::string key)
{
static std::map<std::string,
std::function<size_type(const size_type, const LinOp*)>>
selector_map{
{{"first_for_top", [](const size_type level, const LinOp*) {
return (level == 0) ? 0 : 1;
}}}};
return selector_map.at(key);
}


typename Multigrid::parameters_type Multigrid::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = Multigrid::build();
auto params = Multigrid::build();

if (auto& obj = config.get("criteria")) {
factory.with_criteria(
gko::config::get_factory_vector<const stop::CriterionFactory>(
params.with_criteria(
config::parse_or_get_factory_vector<const stop::CriterionFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("mg_level")) {
factory.with_mg_level(
gko::config::get_factory_vector<const gko::LinOpFactory>(
params.with_mg_level(
config::parse_or_get_factory_vector<const gko::LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("level_selector")) {
factory.with_level_selector(get_selector(obj.get_string()));
}
if (auto& obj = config.get("pre_smoother")) {
factory.with_pre_smoother(
gko::config::get_factory_vector<const LinOpFactory>(obj, context,
td_for_child));
params.with_pre_smoother(
config::parse_or_get_factory_vector<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("post_smoother")) {
factory.with_post_smoother(
gko::config::get_factory_vector<const LinOpFactory>(obj, context,
td_for_child));
params.with_post_smoother(
config::parse_or_get_factory_vector<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("mid_smoother")) {
factory.with_mid_smoother(
gko::config::get_factory_vector<const LinOpFactory>(obj, context,
td_for_child));
params.with_mid_smoother(
config::parse_or_get_factory_vector<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("post_uses_pre")) {
factory.with_post_uses_pre(gko::config::get_value<bool>(obj));
params.with_post_uses_pre(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("mid_case")) {
auto str = obj.get_string();
if (str == "both") {
factory.with_mid_case(multigrid::mid_smooth_type::both);
params.with_mid_case(multigrid::mid_smooth_type::both);
} else if (str == "post_smoother") {
factory.with_mid_case(multigrid::mid_smooth_type::post_smoother);
params.with_mid_case(multigrid::mid_smooth_type::post_smoother);
} else if (str == "pre_smoother") {
factory.with_mid_case(multigrid::mid_smooth_type::pre_smoother);
params.with_mid_case(multigrid::mid_smooth_type::pre_smoother);
} else if (str == "standalone") {
factory.with_mid_case(multigrid::mid_smooth_type::standalone);
params.with_mid_case(multigrid::mid_smooth_type::standalone);
} else {
GKO_INVALID_STATE("Not valid mid_smooth_type value");
GKO_INVALID_CONFIG_VALUE("mid_smooth_type", str);
}
}
if (auto& obj = config.get("max_levels")) {
factory.with_max_levels(gko::config::get_value<size_type>(obj));
params.with_max_levels(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("min_coarse_rows")) {
factory.with_min_coarse_rows(gko::config::get_value<size_type>(obj));
params.with_min_coarse_rows(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("coarsest_solver")) {
factory.with_coarsest_solver(
gko::config::get_factory_vector<const LinOpFactory>(obj, context,
td_for_child));
}
if (auto& obj = config.get("solver_selector")) {
auto str = obj.get_string();
factory.with_solver_selector(get_selector(str));
params.with_coarsest_solver(
config::parse_or_get_factory_vector<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("cycle")) {
auto str = obj.get_string();
if (str == "v") {
factory.with_cycle(multigrid::cycle::v);
params.with_cycle(multigrid::cycle::v);
} else if (str == "w") {
factory.with_cycle(multigrid::cycle::w);
params.with_cycle(multigrid::cycle::w);
} else if (str == "f") {
factory.with_cycle(multigrid::cycle::f);
params.with_cycle(multigrid::cycle::f);
} else {
GKO_INVALID_STATE("Not valid cycle value");
GKO_INVALID_CONFIG_VALUE("cycle", str);
}
}
if (auto& obj = config.get("kcycle_base")) {
factory.with_kcycle_base(gko::config::get_value<size_type>(obj));
params.with_kcycle_base(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("kcycle_rel_tol")) {
factory.with_kcycle_rel_tol(gko::config::get_value<double>(obj));
params.with_kcycle_rel_tol(gko::config::get_value<double>(obj));
}
if (auto& obj = config.get("smoother_relax")) {
factory.with_smoother_relax(
params.with_smoother_relax(
gko::config::get_value<std::complex<double>>(obj));
}
if (auto& obj = config.get("smoother_iters")) {
factory.with_smoother_iters(gko::config::get_value<size_type>(obj));
params.with_smoother_iters(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("default_initial_guess")) {
factory.with_default_initial_guess(
params.with_default_initial_guess(
gko::config::get_value<solver::initial_guess_mode>(obj));
}
return factory;
return params;
}


Expand Down
Loading

0 comments on commit 528d27e

Please sign in to comment.