Skip to content

Commit

Permalink
disallow setting indextype and rename var
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed May 27, 2024
1 parent 2ad37ae commit 569148f
Show file tree
Hide file tree
Showing 15 changed files with 67 additions and 69 deletions.
12 changes: 6 additions & 6 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ namespace config {
deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Time::build();
auto params = stop::Time::build();
if (auto& obj = config.get("time_limit")) {
factory.with_time_limit(gko::config::get_value<long long int>(obj));
params.with_time_limit(gko::config::get_value<long long int>(obj));
}
return factory;
return params;
}


deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td)
{
auto factory = stop::Iteration::build();
auto params = stop::Iteration::build();
if (auto& obj = config.get("max_iters")) {
factory.with_max_iters(gko::config::get_value<size_type>(obj));
params.with_max_iters(gko::config::get_value<size_type>(obj));
}
return factory;
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ typename Bicg<ValueType>::parameters_type Bicg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Bicg<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
return factory;
auto params = solver::Bicg<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ typename Bicgstab<ValueType>::parameters_type Bicgstab<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Bicgstab<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
return factory;
auto params = solver::Bicgstab<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
10 changes: 5 additions & 5 deletions core/solver/cb_gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ typename CbGmres<ValueType>::parameters_type CbGmres<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::CbGmres<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
auto params = solver::CbGmres<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
params.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("storage_precision")) {
auto get_storage_precision = [](std::string str) {
Expand All @@ -187,9 +187,9 @@ typename CbGmres<ValueType>::parameters_type CbGmres<ValueType>::parse(
}
GKO_INVALID_STATE("Wrong value for storage_precision");
};
factory.with_storage_precision(get_storage_precision(obj.get_string()));
params.with_storage_precision(get_storage_precision(obj.get_string()));
}
return factory;
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ typename Cg<ValueType>::parameters_type Cg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Cg<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
return factory;
auto params = solver::Cg<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ typename Cgs<ValueType>::parameters_type Cgs<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Cgs<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
return factory;
auto params = solver::Cgs<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
8 changes: 4 additions & 4 deletions core/solver/direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ Direct<ValueType, IndexType>::parse(const config::pnode& config,
const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = Direct<ValueType, IndexType>::build();
auto params = Direct<ValueType, IndexType>::build();
if (auto& obj = config.get("num_rhs")) {
factory.with_num_rhs(gko::config::get_value<size_type>(obj));
params.with_num_rhs(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("factorization")) {
factory.with_factorization(
params.with_factorization(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
return factory;
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/fcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ typename Fcg<ValueType>::parameters_type Fcg<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Fcg<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
return factory;
auto params = solver::Fcg<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
8 changes: 4 additions & 4 deletions core/solver/gcr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ typename Gcr<ValueType>::parameters_type Gcr<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Gcr<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
auto params = solver::Gcr<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
params.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
return factory;
return params;
}


Expand Down
10 changes: 5 additions & 5 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ typename Gmres<ValueType>::parameters_type Gmres<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Gmres<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
auto params = solver::Gmres<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
if (auto& obj = config.get("krylov_dim")) {
factory.with_krylov_dim(gko::config::get_value<size_type>(obj));
params.with_krylov_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("flexible")) {
factory.with_flexible(gko::config::get_value<bool>(obj));
params.with_flexible(gko::config::get_value<bool>(obj));
}
return factory;
return params;
}


Expand Down
14 changes: 7 additions & 7 deletions core/solver/idr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ typename Idr<ValueType>::parameters_type Idr<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Idr<ValueType>::build();
common_solver_parse(factory, config, context, td_for_child);
auto params = solver::Idr<ValueType>::build();
common_solver_parse(params, config, context, td_for_child);
if (auto& obj = config.get("subspace_dim")) {
factory.with_subspace_dim(gko::config::get_value<size_type>(obj));
params.with_subspace_dim(gko::config::get_value<size_type>(obj));
}
if (auto& obj = config.get("kappa")) {
factory.with_kappa(
params.with_kappa(
gko::config::get_value<remove_complex<ValueType>>(obj));
}
if (auto& obj = config.get("deterministic")) {
factory.with_deterministic(gko::config::get_value<bool>(obj));
params.with_deterministic(gko::config::get_value<bool>(obj));
}
if (auto& obj = config.get("complex_subspace")) {
factory.with_complex_subspace(gko::config::get_value<bool>(obj));
params.with_complex_subspace(gko::config::get_value<bool>(obj));
}
return factory;
return params;
}


Expand Down
14 changes: 7 additions & 7 deletions core/solver/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,29 @@ typename Ir<ValueType>::parameters_type Ir<ValueType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto factory = solver::Ir<ValueType>::build();
auto params = solver::Ir<ValueType>::build();
if (auto& obj = config.get("criteria")) {
factory.with_criteria(
params.with_criteria(
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
}
if (auto& obj = config.get("solver")) {
factory.with_solver(
params.with_solver(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("generated_solver")) {
factory.with_generated_solver(
params.with_generated_solver(
gko::config::get_stored_obj<const LinOp>(obj, context));
}
if (auto& obj = config.get("relaxation_factor")) {
factory.with_relaxation_factor(gko::config::get_value<ValueType>(obj));
params.with_relaxation_factor(gko::config::get_value<ValueType>(obj));
}
if (auto& obj = config.get("default_initial_guess")) {
factory.with_default_initial_guess(
params.with_default_initial_guess(
gko::config::get_value<solver::initial_guess_mode>(obj));
}
return factory;
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ LowerTrs<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto param = LowerTrs<ValueType, IndexType>::build();
common_trisolver_parse(param, config, context, td_for_child);
return param;
auto params = LowerTrs<ValueType, IndexType>::build();
common_trisolver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
6 changes: 3 additions & 3 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ UpperTrs<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto param = UpperTrs<ValueType, IndexType>::build();
common_trisolver_parse(param, config, context, td_for_child);
return param;
auto params = UpperTrs<ValueType, IndexType>::build();
common_trisolver_parse(params, config, context, td_for_child);
return params;
}


Expand Down
18 changes: 8 additions & 10 deletions core/test/config/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ struct CbGmres : SolverConfigTest<gko::solver::CbGmres<float>,


struct Direct
: SolverConfigTest<gko::experimental::solver::Direct<float, gko::int64>,
: SolverConfigTest<gko::experimental::solver::Direct<float, int>,
gko::experimental::solver::Direct<double, int>> {
static pnode::map_type setup_base()
{
Expand All @@ -351,7 +351,6 @@ struct Direct
static void change_template(pnode::map_type& config_map)
{
config_map["value_type"] = pnode{"float32"};
config_map["index_type"] = pnode{"int64"};
}

template <bool from_reg, typename ParamType>
Expand Down Expand Up @@ -393,11 +392,10 @@ struct Direct


template <template <class, class> class Trs>
struct TrsHelper : SolverConfigTest<Trs<float, gko::int64>, Trs<double, int>> {
struct TrsHelper : SolverConfigTest<Trs<float, int>, Trs<double, int>> {
static void change_template(pnode::map_type& config_map)
{
config_map["value_type"] = pnode{"float32"};
config_map["index_type"] = pnode{"int64"};
}

template <bool from_reg, typename ParamType>
Expand Down Expand Up @@ -507,12 +505,12 @@ TYPED_TEST(Solver, SetFromRegistry)
using Config = typename TestFixture::Config;
auto config_map = Config::setup_base();
Config::change_template(config_map);
auto param = Config::changed_solver_type::build();
Config::template set<true>(config_map, param, this->reg, this->exec);
auto params = Config::changed_solver_type::build();
Config::template set<true>(config_map, params, this->reg, this->exec);
auto config = pnode(config_map);

auto res = parse(config, this->reg, this->td).on(this->exec);
auto ans = param.on(this->exec);
auto ans = params.on(this->exec);

Config::template validate<true>(res.get(), ans.get());
}
Expand All @@ -523,12 +521,12 @@ TYPED_TEST(Solver, SetFromConfig)
using Config = typename TestFixture::Config;
auto config_map = Config::setup_base();
Config::change_template(config_map);
auto param = Config::changed_solver_type::build();
Config::template set<false>(config_map, param, this->reg, this->exec);
auto params = Config::changed_solver_type::build();
Config::template set<false>(config_map, params, this->reg, this->exec);
auto config = pnode(config_map);

auto res = parse(config, this->reg, this->td).on(this->exec);
auto ans = param.on(this->exec);
auto ans = params.on(this->exec);

Config::template validate<false>(res.get(), ans.get());
}

0 comments on commit 569148f

Please sign in to comment.