Skip to content

Commit

Permalink
Merge Add solver config except for the multigrid
Browse files Browse the repository at this point in the history
This PR adds the solver file config except for the multigrid part

Related PR: #1395
  • Loading branch information
yhmtsai authored May 27, 2024
2 parents eeaa7aa + 569148f commit 591ace6
Show file tree
Hide file tree
Showing 37 changed files with 1,282 additions and 50 deletions.
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ target_sources(ginkgo
config/config_helper.cpp
config/property_tree.cpp
config/registry.cpp
config/solver_config.cpp
config/stop_config.cpp
config/type_descriptor.cpp
distributed/index_map.cpp
Expand Down
57 changes: 55 additions & 2 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,21 @@ namespace config {
* LinOpFactoryType enum is to avoid forward declaration, linopfactory header,
* two template versions of parse
*/
enum class LinOpFactoryType : int { Cg = 0 };
enum class LinOpFactoryType : int {
Cg = 0,
Bicg,
Bicgstab,
Fcg,
Cgs,
Ir,
Idr,
Gcr,
Gmres,
CbGmres,
Direct,
LowerTrs,
UpperTrs
};


/**
Expand Down Expand Up @@ -107,13 +121,29 @@ inline std::vector<deferred_factory_parameter<T>> parse_or_get_factory_vector(
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for bool type
*/
template <typename ValueType>
inline std::enable_if_t<std::is_same<ValueType, bool>::value, bool> get_value(
const pnode& config)
{
auto val = config.get_boolean();
return val;
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for integral type
*/
template <typename IndexType>
inline std::enable_if_t<std::is_integral<IndexType>::value, IndexType>
inline std::enable_if_t<std::is_integral<IndexType>::value &&
!std::is_same<IndexType, bool>::value,
IndexType>
get_value(const pnode& config)
{
auto val = config.get_integer();
Expand Down Expand Up @@ -173,6 +203,29 @@ get_value(const pnode& config)
}


/**
* get_value gets the corresponding type value from config.
*
* This is specialization for initial_guess_mode
*/
template <typename ValueType>
inline std::enable_if_t<
std::is_same<ValueType, solver::initial_guess_mode>::value,
solver::initial_guess_mode>
get_value(const pnode& config)
{
auto val = config.get_string();
if (val == "zero") {
return solver::initial_guess_mode::zero;
} else if (val == "rhs") {
return solver::initial_guess_mode::rhs;
} else if (val == "provided") {
return solver::initial_guess_mode::provided;
}
GKO_INVALID_STATE("Wrong value for initial_guess_mode");
}


} // namespace config
} // namespace gko

Expand Down
2 changes: 2 additions & 0 deletions core/config/dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
Expand Down Expand Up @@ -106,6 +107,7 @@ deferred_factory_parameter<ReturnType> dispatch(
using value_type_list =
syn::type_list<double, float, std::complex<double>, std::complex<float>>;

using index_type_list = syn::type_list<int32, int64>;

} // namespace config
} // namespace gko
Expand Down
61 changes: 61 additions & 0 deletions core/config/parse_macro.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_PARSE_MACRO_HPP_
#define GKO_CORE_CONFIG_PARSE_MACRO_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/config/type_descriptor.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/type_descriptor_helper.hpp"


// for value_type only
#define GKO_PARSE_VALUE_TYPE(_type, _configurator) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<gko::config::LinOpFactoryType::_type>( \
const gko::config::pnode& config, \
const gko::config::registry& context, \
const gko::config::type_descriptor& td) \
{ \
auto updated = gko::config::update_type(config, td); \
return gko::config::dispatch<gko::LinOpFactory, _configurator>( \
config, context, updated, \
gko::config::make_type_selector(updated.get_value_typestr(), \
gko::config::value_type_list())); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


// for value_type and index_type
#define GKO_PARSE_VALUE_AND_INDEX_TYPE(_type, _configurator) \
template <> \
deferred_factory_parameter<gko::LinOpFactory> \
parse<gko::config::LinOpFactoryType::_type>( \
const gko::config::pnode& config, \
const gko::config::registry& context, \
const gko::config::type_descriptor& td) \
{ \
auto updated = gko::config::update_type(config, td); \
return gko::config::dispatch<gko::LinOpFactory, _configurator>( \
config, context, updated, \
gko::config::make_type_selector(updated.get_value_typestr(), \
gko::config::value_type_list()), \
gko::config::make_type_selector(updated.get_index_typestr(), \
gko::config::index_type_list())); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


#endif // GKO_CORE_CONFIG_PARSE_MACRO_HPP_
14 changes: 13 additions & 1 deletion core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,19 @@ namespace config {

configuration_map generate_config_map()
{
return {{"solver::Cg", parse<LinOpFactoryType::Cg>}};
return {{"solver::Cg", parse<LinOpFactoryType::Cg>},
{"solver::Bicg", parse<LinOpFactoryType::Bicg>},
{"solver::Bicgstab", parse<LinOpFactoryType::Bicgstab>},
{"solver::Fcg", parse<LinOpFactoryType::Fcg>},
{"solver::Cgs", parse<LinOpFactoryType::Cgs>},
{"solver::Ir", parse<LinOpFactoryType::Ir>},
{"solver::Idr", parse<LinOpFactoryType::Idr>},
{"solver::Gcr", parse<LinOpFactoryType::Gcr>},
{"solver::Gmres", parse<LinOpFactoryType::Gmres>},
{"solver::CbGmres", parse<LinOpFactoryType::CbGmres>},
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>}};
}


Expand Down
48 changes: 48 additions & 0 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/bicg.hpp>
#include <ginkgo/core/solver/bicgstab.hpp>
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
#include <ginkgo/core/solver/gmres.hpp>
#include <ginkgo/core/solver/idr.hpp>
#include <ginkgo/core/solver/ir.hpp>
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
#include "core/config/parse_macro.hpp"
#include "core/config/solver_config.hpp"


namespace gko {
namespace config {


GKO_PARSE_VALUE_TYPE(Cg, gko::solver::Cg);
GKO_PARSE_VALUE_TYPE(Bicg, gko::solver::Bicg);
GKO_PARSE_VALUE_TYPE(Bicgstab, gko::solver::Bicgstab);
GKO_PARSE_VALUE_TYPE(Cgs, gko::solver::Cgs);
GKO_PARSE_VALUE_TYPE(Fcg, gko::solver::Fcg);
GKO_PARSE_VALUE_TYPE(Ir, gko::solver::Ir);
GKO_PARSE_VALUE_TYPE(Idr, gko::solver::Idr);
GKO_PARSE_VALUE_TYPE(Gcr, gko::solver::Gcr);
GKO_PARSE_VALUE_TYPE(Gmres, gko::solver::Gmres);
GKO_PARSE_VALUE_TYPE(CbGmres, gko::solver::CbGmres);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct);
GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs);
GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs);


} // namespace config
} // namespace gko
45 changes: 45 additions & 0 deletions core/config/solver_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
#define GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"

namespace gko {
namespace config {


template <typename SolverParam>
inline void common_solver_parse(SolverParam& params, const pnode& config,
const registry& context,
type_descriptor td_for_child)
{
if (auto& obj = config.get("generated_preconditioner")) {
params.with_generated_preconditioner(
gko::config::get_stored_obj<const LinOp>(obj, context));
}
if (auto& obj = config.get("criteria")) {
params.with_criteria(
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
params.with_preconditioner(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
}


} // namespace config
} // namespace gko

#endif // GKO_CORE_CONFIG_SOLVER_CONFIG_HPP_
12 changes: 6 additions & 6 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ namespace config {
deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Time::build();
auto params = stop::Time::build();
if (auto& obj = config.get("time_limit")) {
factory.with_time_limit(gko::config::get_value<long long int>(obj));
params.with_time_limit(gko::config::get_value<long long int>(obj));
}
return factory;
return params;
}


deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Iteration::build();
auto params = stop::Iteration::build();
if (auto& obj = config.get("max_iters")) {
factory.with_max_iters(gko::config::get_value<size_type>(obj));
params.with_max_iters(gko::config::get_value<size_type>(obj));
}
return factory;
return params;
}


Expand Down
49 changes: 49 additions & 0 deletions core/config/trisolver_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_
#define GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_


#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/solver/triangular.hpp>


#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"

namespace gko {
namespace config {


template <typename SolverParam>
inline void common_trisolver_parse(SolverParam& params, const pnode& config,
const registry& context,
type_descriptor td_for_child)
{
if (auto& obj = config.get("num_rhs")) {
params.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("unit_diagonal")) {
params.with_unit_diagonal(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("algorithm")) {
using gko::solver::trisolve_algorithm;
auto str = obj.get_string();
if (str == "sparselib") {
params.with_algorithm(trisolve_algorithm::sparselib);
} else if (str == "syncfree") {
params.with_algorithm(trisolve_algorithm::syncfree);
} else {
GKO_INVALID_STATE("Wrong value for algorithm");
}
}
}


} // namespace config
} // namespace gko

#endif // GKO_CORE_CONFIG_TRISOLVER_CONFIG_HPP_
12 changes: 12 additions & 0 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/base/precision_dispatch.hpp>


#include "core/config/solver_config.hpp"
#include "core/solver/bicg_kernels.hpp"
#include "core/solver/solver_boilerplate.hpp"

Expand All @@ -32,6 +33,17 @@ GKO_REGISTER_OPERATION(step_2, bicg::step_2);
} // namespace bicg


template <typename ValueType>
typename Bicg<ValueType>::parameters_type Bicg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto params = solver::Bicg<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


template <typename ValueType>
std::unique_ptr<LinOp> Bicg<ValueType>::transpose() const
{
Expand Down
Loading

0 comments on commit 591ace6

Please sign in to comment.