Skip to content

Commit

Permalink
adapt with the explicit deferred type
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 27, 2023
1 parent d7a74ba commit 9ccf9b9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
21 changes: 11 additions & 10 deletions core/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,16 @@ inline std::shared_ptr<T> get_pointer(const pnode& config,


template <typename T>
inline deferred_factory_parameter<std::remove_const_t<T>> get_factory(
const pnode& config, const registry& context, type_descriptor td);
inline deferred_factory_parameter<T> get_factory(const pnode& config,
const registry& context,
type_descriptor td);

template <>
inline deferred_factory_parameter<LinOpFactory> get_factory<const LinOpFactory>(
const pnode& config, const registry& context, type_descriptor td)
inline deferred_factory_parameter<const LinOpFactory>
get_factory<const LinOpFactory>(const pnode& config, const registry& context,
type_descriptor td)
{
deferred_factory_parameter<LinOpFactory> ptr;
deferred_factory_parameter<const LinOpFactory> ptr;
if (config.is(pnode::status_t::data)) {
ptr = context.search_data<LinOpFactory>(config.get_data<std::string>());
} else if (config.is(pnode::status_t::map)) {
Expand All @@ -105,18 +107,17 @@ inline deferred_factory_parameter<LinOpFactory> get_factory<const LinOpFactory>(
}

template <>
deferred_factory_parameter<stop::CriterionFactory>
deferred_factory_parameter<const stop::CriterionFactory>
get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
type_descriptor td);


template <typename T>
inline std::vector<deferred_factory_parameter<std::remove_const_t<T>>>
get_factory_vector(const pnode& config, const registry& context,
type_descriptor td)
inline std::vector<deferred_factory_parameter<T>> get_factory_vector(
const pnode& config, const registry& context, type_descriptor td)
{
std::vector<deferred_factory_parameter<std::remove_const_t<T>>> res;
std::vector<deferred_factory_parameter<T>> res;
// for loop in config
if (config.is(pnode::status_t::array)) {
for (const auto& it : config.get_array()) {
Expand Down
4 changes: 2 additions & 2 deletions core/config/stop_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ configure_implicit_residual(const pnode& config, const registry& context,


template <>
deferred_factory_parameter<stop::CriterionFactory>
deferred_factory_parameter<const stop::CriterionFactory>
get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
type_descriptor td)
{
deferred_factory_parameter<stop::CriterionFactory> ptr;
deferred_factory_parameter<const stop::CriterionFactory> ptr;
if (config.is(pnode::status_t::data)) {
return context.search_data<stop::CriterionFactory>(
config.get_data<std::string>());
Expand Down
21 changes: 10 additions & 11 deletions core/test/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ TEST_F(Config, GenerateObjectWithData)
{"criteria", this->stop_config}}};
auto obj = build_from_config<0>(p, reg, {"float", ""}).on(this->exec);

ASSERT_NE(dynamic_cast<const gko::solver::Cg<float>::Factory*>(obj.get()),
ASSERT_NE(dynamic_cast<gko::solver::Cg<float>::Factory*>(obj.get()),
nullptr);
ASSERT_NE(dynamic_cast<const gko::solver::Cg<float>::Factory*>(obj.get())
ASSERT_NE(dynamic_cast<gko::solver::Cg<float>::Factory*>(obj.get())
->get_parameters()
.generated_preconditioner,
nullptr);
Expand All @@ -121,9 +121,9 @@ TEST_F(Config, GenerateObjectWithPreconditioner)
pnode{{{"Type", pnode{"Cg"}}, {"criteria", this->stop_config}}};
auto obj = build_from_config<0>(p, reg).on(this->exec);

ASSERT_NE(dynamic_cast<const gko::solver::Cg<double>::Factory*>(obj.get()),
ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
nullptr);
ASSERT_NE(dynamic_cast<const gko::solver::Cg<double>::Factory*>(obj.get())
ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get())
->get_parameters()
.preconditioner,
nullptr);
Expand All @@ -146,14 +146,13 @@ TEST_F(Config, GenerateObjectWithCustomBuild)
pnode{std::map<std::string, pnode>{{"Type", pnode{"Custom"}}}};
auto obj = build_from_config<0>(p, reg, {"double", ""}).on(this->exec);

ASSERT_NE(dynamic_cast<const gko::solver::Cg<double>::Factory*>(obj.get()),
ASSERT_NE(dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get()),
nullptr);
ASSERT_NE(dynamic_cast<const gko::solver::Bicg<double>::Factory*>(
dynamic_cast<gko::solver::Cg<double>::Factory*>(obj.get())
->get_parameters()
.preconditioner.get()),
nullptr);
ASSERT_NE(
dynamic_cast<const gko::solver::Bicg<double>::Factory*>(
dynamic_cast<const gko::solver::Cg<double>::Factory*>(obj.get())
->get_parameters()
.preconditioner.get()),
nullptr);
}


Expand Down

0 comments on commit 9ccf9b9

Please sign in to comment.