Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <[email protected]>
  • Loading branch information
pratikvn and yhmtsai committed Nov 5, 2023
1 parent b8babda commit f48179b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 117 deletions.
39 changes: 20 additions & 19 deletions dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class KernelCaller {
{}

template <typename StopType, const int subgroup_size,
const int n_shared_total, const bool sg_kernel_all,
typename PrecType, typename LogType, typename BatchMatrixType>
const int n_shared_total, typename PrecType, typename LogType,
typename BatchMatrixType>
__dpct_inline__ void launch_apply_kernel(
const gko::kernels::batch_bicgstab::storage_config& sconf,
LogType& logger, PrecType& prec, const BatchMatrixType mat,
Expand All @@ -118,9 +118,10 @@ class KernelCaller {
slm_values(sycl::range<1>(shared_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
subgroup_size)]] [[intel::kernel_args_restrict]] {
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
subgroup_size)]] [
[intel::kernel_args_restrict]] {
auto batch_id = item_ct1.get_group_linear_id();
const auto mat_global_entry =
gko::batch::matrix::extract_batch_item(mat, batch_id);
Expand All @@ -130,7 +131,7 @@ class KernelCaller {
ValueType* const x_global_entry =
gko::batch::multi_vector::batch_item_ptr(
x_values, 1, num_rows, batch_id);
apply_kernel<StopType, n_shared_total, sg_kernel_all>(
apply_kernel<StopType, n_shared_total>(
sconf, max_iters, res_tol, logger, prec,
mat_global_entry, b_global_entry, x_global_entry,
num_rows, mat.get_single_item_num_nnz(),
Expand Down Expand Up @@ -197,67 +198,67 @@ class KernelCaller {
// Explicit template arguments, in order:
// launch_apply_kernel<StopType, subgroup_size, n_shared_total>
if (num_rows <= 32 && n_shared_total == 10) {
launch_apply_kernel<StopType, 32, 10, true>(
launch_apply_kernel<StopType, 32, 10>(
sconf, logger, prec, mat, b.values, x.values, workspace_data,
group_size, shared_size);
} else if (num_rows <= 256 && n_shared_total == 10) {
launch_apply_kernel<StopType, 32, 10, true>(
launch_apply_kernel<StopType, 32, 10>(
sconf, logger, prec, mat, b.values, x.values, workspace_data,
group_size, shared_size);
} else {
switch (n_shared_total) {
case 0:
launch_apply_kernel<StopType, 32, 0, true>(
launch_apply_kernel<StopType, 32, 0>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 1:
launch_apply_kernel<StopType, 32, 1, true>(
launch_apply_kernel<StopType, 32, 1>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 2:
launch_apply_kernel<StopType, 32, 2, true>(
launch_apply_kernel<StopType, 32, 2>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 3:
launch_apply_kernel<StopType, 32, 3, true>(
launch_apply_kernel<StopType, 32, 3>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 4:
launch_apply_kernel<StopType, 32, 4, true>(
launch_apply_kernel<StopType, 32, 4>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 5:
launch_apply_kernel<StopType, 32, 5, true>(
launch_apply_kernel<StopType, 32, 5>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 6:
launch_apply_kernel<StopType, 32, 6, true>(
launch_apply_kernel<StopType, 32, 6>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 7:
launch_apply_kernel<StopType, 32, 7, true>(
launch_apply_kernel<StopType, 32, 7>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 8:
launch_apply_kernel<StopType, 32, 8, true>(
launch_apply_kernel<StopType, 32, 8>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 9:
launch_apply_kernel<StopType, 32, 9, true>(
launch_apply_kernel<StopType, 32, 9>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
case 10:
launch_apply_kernel<StopType, 32, 10, true>(
launch_apply_kernel<StopType, 32, 10>(
sconf, logger, prec, mat, b.values, x.values,
workspace_data, group_size, shared_size);
break;
Expand Down
140 changes: 50 additions & 90 deletions dpcpp/solver/batch_bicgstab_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

template <const bool sg_kernel_all, typename BatchMatrixType_entry,
typename ValueType>
template <typename BatchMatrixType_entry, typename ValueType>
__dpct_inline__ void initialize(
const int num_rows, const BatchMatrixType_entry& mat_global_entry,
const ValueType* const b_global_entry,
Expand Down Expand Up @@ -68,17 +67,12 @@ __dpct_inline__ void initialize(
r_shared_entry, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if constexpr (sg_kernel_all) {
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
}
} else {
single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1);
single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1);
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

Expand Down Expand Up @@ -111,7 +105,7 @@ __dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new,
}


template <const bool sg_kernel_all, typename ValueType>
template <typename ValueType>
__dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
const ValueType* const r_hat_shared_entry,
const ValueType* const v_shared_entry,
Expand All @@ -120,23 +114,15 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
auto sg = item_ct1.get_sub_group();
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if constexpr (sg_kernel_all) {
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
alpha = rho_new / alpha;
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
} else {
single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
if (tid == 0) {
alpha = rho_new / alpha;
}
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
v_shared_entry, alpha, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
alpha = rho_new / alpha;
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}


Expand All @@ -155,7 +141,7 @@ __dpct_inline__ void update_s(const int num_rows,
}


template <const bool sg_kernel_all, typename ValueType>
template <typename ValueType>
__dpct_inline__ void compute_omega(const int num_rows,
const ValueType* const t_shared_entry,
const ValueType* const s_shared_entry,
Expand All @@ -165,28 +151,18 @@ __dpct_inline__ void compute_omega(const int num_rows,
auto sg = item_ct1.get_sub_group();
const auto sg_id = sg.get_group_id();
const auto tid = item_ct1.get_local_linear_id();
if constexpr (sg_kernel_all) {
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
s_shared_entry, omega, item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
t_shared_entry, temp, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
omega /= temp;
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
} else {
single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry,
omega, item_ct1);
single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry,
temp, item_ct1);
if (tid == 0) {
omega /= temp;
}
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry,
omega, item_ct1);
} else if (sg_id == 1) {
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry,
temp, item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
if (tid == 0) {
omega /= temp;
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}


Expand Down Expand Up @@ -220,9 +196,8 @@ __dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha,
}


template <typename StopType, const int n_shared_total, const bool sg_kernel_all,
typename PrecType, typename LogType, typename BatchMatrixType,
typename ValueType>
template <typename StopType, const int n_shared_total, typename PrecType,
typename LogType, typename BatchMatrixType, typename ValueType>
void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
Expand Down Expand Up @@ -344,10 +319,10 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
// p = 0
// p_hat = 0
// v = 0
initialize<sg_kernel_all>(num_rows, mat_global_entry, b_global_entry,
x_global_entry, rho_old_sh[0], omega_sh[0],
alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh,
v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1);
initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry,
rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh,
p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0],
item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// stopping criterion object
Expand All @@ -361,16 +336,11 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
}

// rho_new = < r_hat , r > = (r_hat)' * (r)
if constexpr (sg_kernel_all) {
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
} else {
single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0],
item_ct1);
if (sg_id == 0) {
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
rho_new_sh[0], item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// beta = (rho_new / rho_old)*(alpha / omega)
// p = r + beta*(p - omega * v)
Expand All @@ -387,24 +357,20 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// alpha = rho_new / < r_hat , v>
compute_alpha<sg_kernel_all>(num_rows, rho_new_sh[0], r_hat_sh, v_sh,
alpha_sh[0], item_ct1);
compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0],
item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// s = r - alpha*v
update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// an estimate of residual norms
if constexpr (sg_kernel_all) {
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
} else {
single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1);
if (sg_id == 0) {
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
item_ct1);
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if (stop.check_converged(norms_res_sh)) {
update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1);
Expand All @@ -421,8 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// omega = <t,s> / <t,t>
compute_omega<sg_kernel_all>(num_rows, t_sh, s_sh, temp_sh[0],
omega_sh[0], item_ct1);
compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

// x = x + alpha*p_hat + omega *s_hat
Expand All @@ -431,18 +396,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
s_sh, t_sh, x_sh, r_sh, item_ct1);
item_ct1.barrier(sycl::access::fence_space::global_and_local);

if constexpr (sg_kernel_all) {
if (sg_id == 0)
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
if (tid == group_size - 1) {
rho_old_sh[0] = rho_new_sh[0];
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
} else {
single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1);
if (sg_id == 0)
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
item_ct1);
if (tid == group_size - 1) {
rho_old_sh[0] = rho_new_sh[0];
}
item_ct1.barrier(sycl::access::fence_space::global_and_local);
}

logger.log_iteration(batch_id, iter, norms_res_sh[0]);
Expand Down
8 changes: 0 additions & 8 deletions include/ginkgo/core/solver/batch_solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,6 @@ class BatchSolver {
};


/**
* The parameter type shared between all preconditioned iterative solvers,
* excluding the parameters available in iterative_solver_factory_parameters.
* @see GKO_CREATE_FACTORY_PARAMETERS
*/
struct preconditioned_iterative_solver_factory_parameters {};


template <typename Parameters, typename Factory>
struct enable_preconditioned_iterative_solver_factory_parameters
: enable_parameters_type<Parameters, Factory> {
Expand Down

0 comments on commit f48179b

Please sign in to comment.