Skip to content

Commit

Permalink
pass alpha/beta by value to kernel and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jan 7, 2025
1 parent 1c81c70 commit 77a3551
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 253 deletions.
1 change: 1 addition & 0 deletions common/unified/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(UNIFIED_SOURCES
solver/cg_kernels.cpp
solver/cgs_kernels.cpp
solver/cgs_kernels.cpp
solver/chebyshev_kernels.cpp
solver/common_gmres_kernels.cpp
solver/fcg_kernels.cpp
solver/gcr_kernels.cpp
Expand Down
16 changes: 7 additions & 9 deletions common/unified/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace chebyshev {

template <typename ValueType, typename ScalarType>
void init_update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha,
const ScalarType alpha,
const matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
Expand All @@ -27,8 +27,8 @@ void init_update(std::shared_ptr<const DefaultExecutor> exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto inner_sol,
auto update_sol, auto output) {
const auto inner_val = inner_sol(row, col);
update_sol(row, col) = val;
output(row, col) += alpha_val * inner_val;
update_sol(row, col) = inner_val;
output(row, col) += alpha * inner_val;
},
output->get_size(), alpha, inner_sol, update_sol, output);
}
Expand All @@ -38,21 +38,19 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(


template <typename ValueType, typename ScalarType>
void update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha, const ScalarType* beta,
matrix::Dense<ValueType>* inner_sol,
void update(std::shared_ptr<const DefaultExecutor> exec, const ScalarType alpha,
const ScalarType beta, matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto inner_sol,
auto update_sol, auto output) {
const auto val =
inner_sol(row, col) + beta[0] * update_sol(row, col);
const auto val = inner_sol(row, col) + beta * update_sol(row, col);
inner_sol(row, col) = val;
update_sol(row, col) = val;
output(row, col) += alpha[0] * val;
output(row, col) += alpha * val;
},
output->get_size(), alpha, beta, inner_sol, update_sol, output);
}
Expand Down
67 changes: 4 additions & 63 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,6 @@ void Chebyshev<ValueType>::apply_with_initial_guess_impl(
}


template <typename Fn>
void visit_criteria(Fn&& fn,
std::shared_ptr<const gko::stop::CriterionFactory> c)
{
fn(c);
if (auto combined =
std::dynamic_pointer_cast<const stop::Combined::Factory>(c)) {
for (const auto& factory : combined->get_parameters().criteria) {
visit_criteria(std::forward<Fn>(fn), factory);
}
}
}


template <typename ValueType>
template <typename VectorType>
void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
Expand All @@ -195,27 +181,6 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
GKO_SOLVER_VECTOR(inner_solution, dense_b);
GKO_SOLVER_VECTOR(update_solution, dense_b);

auto old_num_max_generation = num_max_generation_;
// Use the scalar first
// get the iteration information from stopping criterion.
visit_criteria(
[&](auto factory) {
if (auto iter = std::dynamic_pointer_cast<
const gko::stop::Iteration::Factory>(factory)) {
num_max_generation_ = std::max(
num_max_generation_, iter->get_parameters().max_iters);
}
},
this->get_stop_criterion_factory());
// Regenerate the vector if we realloc the memory.
if (old_num_max_generation != num_max_generation_) {
num_generated_scalar_ = 0;
}
auto alpha = this->template create_workspace_scalar<ValueType>(
GKO_SOLVER_TRAITS::alpha, num_max_generation_ + 1);
auto beta = this->template create_workspace_scalar<ValueType>(
GKO_SOLVER_TRAITS::beta, num_max_generation_ + 1);

GKO_SOLVER_ONE_MINUS_ONE();

auto alpha_ref = ValueType{1} / center_;
Expand Down Expand Up @@ -263,24 +228,11 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
inner_solution->copy_from(residual_ptr);
}
this->get_preconditioner()->apply(residual_ptr, inner_solution);
size_type index =
(iter >= num_max_generation_) ? num_max_generation_ : iter;
auto alpha_scalar =
alpha->create_submatrix(span{0, 1}, span{index, index + 1});
auto beta_scalar =
beta->create_submatrix(span{0, 1}, span{index, index + 1});
if (iter == 0) {
if (num_generated_scalar_ < num_max_generation_) {
alpha_scalar->fill(alpha_ref);
// unused beta for first iteration, but fill zero
beta_scalar->fill(zero<ValueType>());
num_generated_scalar_++;
}
// x = x + alpha * inner_solution
// update_solution = inner_solution
exec->run(chebyshev::make_init_update(
alpha_scalar->get_const_values(),
gko::detail::get_local(inner_solution),
alpha_ref, gko::detail::get_local(inner_solution),
gko::detail::get_local(update_solution),
gko::detail::get_local(dense_x)));
continue;
Expand All @@ -291,21 +243,11 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
(foci_direction_ * alpha_ref / ValueType{2.0});
}
alpha_ref = ValueType{1.0} / (center_ - beta_ref / alpha_ref);
// The last one is always the updated one
if (num_generated_scalar_ < num_max_generation_ ||
iter >= num_max_generation_) {
alpha_scalar->fill(alpha_ref);
beta_scalar->fill(beta_ref);
}
if (num_generated_scalar_ < num_max_generation_) {
num_generated_scalar_++;
}
// z = z + beta * p
// p = z
// x += alpha * p
exec->run(chebyshev::make_update(
alpha_scalar->get_const_values(), beta_scalar->get_const_values(),
gko::detail::get_local(inner_solution),
alpha_ref, beta_ref, gko::detail::get_local(inner_solution),
gko::detail::get_local(update_solution),
gko::detail::get_local(dense_x)));
}
Expand Down Expand Up @@ -351,7 +293,7 @@ int workspace_traits<Chebyshev<ValueType>>::num_arrays(const Solver&)
template <typename ValueType>
int workspace_traits<Chebyshev<ValueType>>::num_vectors(const Solver&)
{
return 7;
return 5;
}


Expand All @@ -360,8 +302,7 @@ std::vector<std::string> workspace_traits<Chebyshev<ValueType>>::op_names(
const Solver&)
{
return {
"residual", "inner_solution", "update_solution", "alpha", "beta",
"one", "minus_one",
"residual", "inner_solution", "update_solution", "one", "minus_one",
};
}

Expand Down
4 changes: 2 additions & 2 deletions core/solver/chebyshev_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ namespace chebyshev {

#define GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL(ValueType, ScalarType) \
void init_update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType* alpha, \
const ScalarType alpha, \
const matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)

#define GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL(ValueType, ScalarType) \
void update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType* alpha, const ScalarType* beta, \
const ScalarType alpha, const ScalarType beta, \
matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)
Expand Down
14 changes: 2 additions & 12 deletions include/ginkgo/core/solver/chebyshev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,6 @@ class Chebyshev final

private:
std::shared_ptr<const LinOp> solver_{};
// num_generated_scalar_ tracks the number of generated scalar alpha
// and beta.
mutable size_type num_generated_scalar_ = 0;
// num_max_generation_ is the number of generated scalar kept in the
// workspace.
mutable size_type num_max_generation_ = 3;
ValueType center_;
ValueType foci_direction_;
};
Expand Down Expand Up @@ -215,14 +209,10 @@ struct workspace_traits<Chebyshev<ValueType>> {
constexpr static int inner_solution = 1;
// update solution
constexpr static int update_solution = 2;
// alpha
constexpr static int alpha = 3;
// beta
constexpr static int beta = 4;
// constant 1.0 scalar
constexpr static int one = 5;
constexpr static int one = 3;
// constant -1.0 scalar
constexpr static int minus_one = 6;
constexpr static int minus_one = 4;

// stopping status array
constexpr static int stop = 0;
Expand Down
16 changes: 6 additions & 10 deletions reference/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@ namespace chebyshev {

template <typename ValueType, typename ScalarType>
void init_update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha,
const ScalarType alpha,
const matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
const auto alpha_val = alpha[0];
for (size_t row = 0; row < output->get_size()[0]; row++) {
for (size_t col = 0; col < output->get_size()[1]; col++) {
const auto inner_val = inner_sol->at(row, col);
update_sol->at(row, col) = inner_val;
output->at(row, col) += alpha_val * inner_val;
output->at(row, col) += alpha * inner_val;
}
}
}
Expand All @@ -34,21 +33,18 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(


template <typename ValueType, typename ScalarType>
void update(std::shared_ptr<const DefaultExecutor> exec,
const ScalarType* alpha, const ScalarType* beta,
matrix::Dense<ValueType>* inner_sol,
void update(std::shared_ptr<const DefaultExecutor> exec, const ScalarType alpha,
const ScalarType beta, matrix::Dense<ValueType>* inner_sol,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
const auto alpha_val = alpha[0];
const auto beta_val = beta[0];
for (size_t row = 0; row < output->get_size()[0]; row++) {
for (size_t col = 0; col < output->get_size()[1]; col++) {
const auto val =
inner_sol->at(row, col) + beta[0] * update_sol->at(row, col);
inner_sol->at(row, col) + beta * update_sol->at(row, col);
inner_sol->at(row, col) = val;
update_sol->at(row, col) = val;
output->at(row, col) += alpha_val * val;
output->at(row, col) += alpha * val;
}
}
}
Expand Down
Loading

0 comments on commit 77a3551

Please sign in to comment.