Skip to content

Commit

Permalink
update deferred_factory and format, fix error
Browse files Browse the repository at this point in the history
csr strategy and array input are not supported
  • Loading branch information
yhmtsai committed Nov 30, 2023
1 parent 7c29ddc commit 0b19c10
Show file tree
Hide file tree
Showing 28 changed files with 671 additions and 765 deletions.
136 changes: 27 additions & 109 deletions core/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,89 +53,55 @@ inline std::shared_ptr<T> get_pointer(const pnode& config,
{
std::shared_ptr<T> ptr;
using T_non_const = std::remove_const_t<T>;
if (config.is(pnode::status_t::object)) {
if (config.is(pnode::status_t::data)) {
ptr = context.search_data<T_non_const>(config.get_data<std::string>());
} else if (config.is(pnode::status_t::list) &&
std::is_convertible<T_non_const*, LinOpFactory*>::value) {
ptr = std::dynamic_pointer_cast<T_non_const>(
gko::share(build_from_config(config, context, exec, td)));
}
assert(ptr.get() != nullptr);
return std::move(ptr);
}


template <typename T>
inline deferred_factory_parameter<T> get_factory(const pnode& config,
const registry& context,
type_descriptor td);

template <>
inline deferred_factory_parameter<const LinOpFactory>
get_factory<const LinOpFactory>(const pnode& config, const registry& context,
type_descriptor td)
inline std::enable_if_t<!std::is_same<T, const stop::CriterionFactory>::value &&
!std::is_same<T, const LinOpFactory>::value,
deferred_factory_parameter<T>>
get_factory(const pnode& config, const registry& context, type_descriptor td)
{
deferred_factory_parameter<const LinOpFactory> ptr;
deferred_factory_parameter<T> ptr;
if (config.is(pnode::status_t::data)) {
ptr = context.search_data<LinOpFactory>(config.get_data<std::string>());
ptr = context.search_data<std::remove_const_t<T>>(
config.get_data<std::string>());
} else if (config.is(pnode::status_t::map)) {
ptr = build_from_config(config, context, td);
ptr = T::product_type::build_from_config(config, context, td);
}
// handle object is config
assert(!ptr.is_empty());
return std::move(ptr);
}


template <typename Csr>
inline std::shared_ptr<typename Csr::strategy_type> get_strategy(
const pnode& config, const registry& context,
std::shared_ptr<const Executor> exec, type_descriptor td)
template <typename T>
inline std::enable_if_t<std::is_same<T, const LinOpFactory>::value,
deferred_factory_parameter<T>>
get_factory(const pnode& config, const registry& context, type_descriptor td)
{
auto str = config.get_data<std::string>();
std::shared_ptr<typename Csr::strategy_type> strategy_ptr;
if (str == "sparselib" || str == "cusparse") {
strategy_ptr = std::make_shared<typename Csr::sparselib>();
} else if (str == "automatical") {
if (auto explicit_exec =
std::dynamic_pointer_cast<const gko::CudaExecutor>(exec)) {
strategy_ptr =
std::make_shared<typename Csr::automatical>(explicit_exec);
} else if (auto explicit_exec =
std::dynamic_pointer_cast<const gko::HipExecutor>(
exec)) {
strategy_ptr =
std::make_shared<typename Csr::automatical>(explicit_exec);
} else {
strategy_ptr = std::make_shared<typename Csr::automatical>(256);
}
} else if (str == "load_balance") {
if (auto explicit_exec =
std::dynamic_pointer_cast<const gko::CudaExecutor>(exec)) {
strategy_ptr =
std::make_shared<typename Csr::load_balance>(explicit_exec);
} else if (auto explicit_exec =
std::dynamic_pointer_cast<const gko::HipExecutor>(
exec)) {
strategy_ptr =
std::make_shared<typename Csr::load_balance>(explicit_exec);
} else {
strategy_ptr = std::make_shared<typename Csr::load_balance>(256);
}

} else if (str == "merge_path") {
strategy_ptr = std::make_shared<typename Csr::merge_path>();
} else if (str == "classical") {
strategy_ptr = std::make_shared<typename Csr::classical>();
deferred_factory_parameter<T> ptr;
if (config.is(pnode::status_t::data)) {
ptr = context.search_data<std::remove_const_t<T>>(
config.get_data<std::string>());
} else if (config.is(pnode::status_t::map)) {
ptr = build_from_config(config, context, td);
}
return std::move(strategy_ptr);
// handle object is config
assert(!ptr.is_empty());
return std::move(ptr);
}

template <>
deferred_factory_parameter<const stop::CriterionFactory>
get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
type_descriptor td);

template <typename T>
std::enable_if_t<std::is_same<T, const stop::CriterionFactory>::value,
deferred_factory_parameter<T>>
get_factory(const pnode& config, const registry& context, type_descriptor td);


template <typename T>
Expand Down Expand Up @@ -239,31 +205,6 @@ get_value(const pnode& config)
}


template <typename T>
struct is_array_t : std::false_type {};

template <typename V>
struct is_array_t<array<V>> : std::true_type {};

template <typename ArrayType>
inline typename std::enable_if<is_array_t<ArrayType>::value, ArrayType>::type
get_value(const pnode& config, std::shared_ptr<const Executor> exec)
{
using T = typename ArrayType::value_type;
std::vector<T> res;
// for loop in config
if (config.is(pnode::status_t::array)) {
for (const auto& it : config.get_array()) {
res.push_back(get_value<T>(it));
}
} else {
// only one config can be passed without array
res.push_back(get_value<T>(config));
}
return ArrayType(exec, res.begin(), res.end());
}


#define SET_POINTER(_factory, _param_type, _param_name, _config, _context, \
_td) \
{ \
Expand Down Expand Up @@ -301,18 +242,6 @@ get_value(const pnode& config, std::shared_ptr<const Executor> exec)
"This assert is used to counter the false positive extra " \
"semi-colon warnings")

#define SET_CSR_STRATEGY(_factory, _csr_type, _param_name, _config, _context, \
_exec, _td) \
{ \
if (_config.contains(#_param_name)) { \
_factory.with_##_param_name(gko::config::get_strategy<_csr_type>( \
_config.at(#_param_name), _context, _exec, _td)); \
} \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


#define SET_VALUE(_factory, _param_type, _param_name, _config) \
{ \
Expand All @@ -325,17 +254,6 @@ get_value(const pnode& config, std::shared_ptr<const Executor> exec)
"This assert is used to counter the false positive extra " \
"semi-colon warnings")

#define SET_VALUE_ARRAY(_factory, _param_type, _param_name, _config, _exec) \
{ \
if (_config.contains(#_param_name)) { \
_factory.with_##_param_name(gko::config::get_value<_param_type>( \
_config.at(#_param_name), _exec)); \
} \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


// If we do not put the build_from_config in the class directly, the following
// can also be in internal header.
Expand Down
Loading

0 comments on commit 0b19c10

Please sign in to comment.