Skip to content

Commit

Permalink
process two elements for half csr spmv
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jan 9, 2025
1 parent 3de8461 commit 81d7814
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
51 changes: 50 additions & 1 deletion common/cuda_hip/matrix/csr_kernels.template.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -533,6 +533,47 @@ __device__ void device_classical_spmv(const size_type num_rows,
}


/**
 * Classical CSR SpMV device routine, overload selected when the matrix
 * values are accessed through a reduced_row_major accessor over const __half.
 *
 * Each subwarp tile walks over rows (starting at its flat subwarp id and
 * advancing by the flat number of subwarps). Within a row, every thread
 * handles two consecutive stored entries per iteration, so one pass of the
 * tile covers `subwarp_size * 2` entries. The per-thread partial sums are
 * combined with `reduce` (addition) and the thread with rank 0 writes
 * `scale(result, c(row, column_id))` into the output.
 *
 * @param num_rows  number of rows to process
 * @param val  accessor for the stored matrix values
 * @param col_idxs  column index of each stored entry
 * @param row_ptrs  row pointers; row r spans [row_ptrs[r], row_ptrs[r + 1])
 * @param b  accessor for the input, indexed as (column index, column_id)
 * @param c  accessor for the output, indexed as (row, column_id)
 * @param scale  closure combining the reduced value with the old c entry
 */
template <size_type subwarp_size, typename AccessType, typename input_accessor,
          typename output_accessor, typename IndexType, typename Closure>
__device__ void device_classical_spmv(
    const size_type num_rows,
    acc::range<acc::reduced_row_major<1u, AccessType, const __half>> val,
    const IndexType* __restrict__ col_idxs,
    const IndexType* __restrict__ row_ptrs, acc::range<input_accessor> b,
    acc::range<output_accessor> c, Closure scale)
{
    using arithmetic_type = typename output_accessor::arithmetic_type;

    // can not use auto for hip because the type is
    // __HIP_Coordinates<__HIP_BlockIdx>::__Y which is not allowed in accessor
    // operator()
    const int column_id = blockIdx.y;

    auto subwarp_tile =
        group::tiled_partition<subwarp_size>(group::this_thread_block());
    // first entry (relative to the row start) handled by this thread; the
    // factor 2 reflects the pair of entries consumed per iteration
    const auto lane_offset = subwarp_tile.thread_rank() * 2;
    const auto row_stride = thread::get_subwarp_num_flat<subwarp_size>();

    for (auto row = thread::get_subwarp_id_flat<subwarp_size>();
         row < num_rows; row += row_stride) {
        const auto row_end = row_ptrs[row + 1];
        auto partial = zero<arithmetic_type>();

        auto ind = row_ptrs[row] + lane_offset;
        while (ind < row_end) {
            partial += val(ind) * b(col_idxs[ind], column_id);
            // the second entry of the pair may fall past the end of the row
            const auto next = ind + 1;
            if (next < row_end) {
                partial += val(next) * b(col_idxs[next], column_id);
            }
            ind += subwarp_size * 2;
        }

        // sum the per-thread partial results over the subwarp tile
        auto subwarp_result = reduce(
            subwarp_tile, partial,
            [](const arithmetic_type& lhs, const arithmetic_type& rhs) {
                return lhs + rhs;
            });

        // only the thread with rank 0 (offset 0) writes the row result
        if (lane_offset == 0) {
            c(row, column_id) = scale(subwarp_result, c(row, column_id));
        }
    }
}


template <size_type subwarp_size, typename matrix_accessor,
typename input_accessor, typename output_accessor, typename IndexType>
__global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
Expand Down Expand Up @@ -2305,6 +2346,10 @@ void spmv(std::shared_ptr<const DefaultExecutor> exec,
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
if (std::is_same<MatrixValueType, gko::half>::value) {
// each thread processes two elements at a time
max_length_per_row /= 2;
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
Expand Down Expand Up @@ -2367,6 +2412,10 @@ void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec,
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
if (std::is_same<MatrixValueType, gko::half>::value) {
// each thread processes two elements at a time
max_length_per_row /= 2;
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
Expand Down
50 changes: 49 additions & 1 deletion dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -642,6 +642,46 @@ void device_classical_spmv(const size_type num_rows,
}


/**
 * Classical CSR SpMV device routine (DPC++), overload selected when the
 * matrix values are accessed through a reduced_row_major accessor over
 * const __half.
 *
 * Each subgroup tile walks over rows (starting at its flat subwarp id and
 * advancing by the flat number of subwarps). Within a row, every work-item
 * handles two consecutive stored entries per iteration, so one pass of the
 * tile covers `subgroup_size * 2` entries. The per-work-item partial sums are
 * combined with `reduce` (addition) and the work-item with rank 0 writes
 * `scale(result, c(row, column_id))` into the output.
 *
 * @param num_rows  number of rows to process
 * @param val  accessor for the stored matrix values
 * @param col_idxs  column index of each stored entry
 * @param row_ptrs  row pointers; row r spans [row_ptrs[r], row_ptrs[r + 1])
 * @param b  accessor for the input, indexed as (column index, column_id)
 * @param c  accessor for the output, indexed as (row, column_id)
 * @param scale  closure combining the reduced value with the old c entry
 * @param item_ct1  the nd_item of the calling work-item
 */
template <size_type subgroup_size, typename AccessType, typename input_accessor,
          typename output_accessor, typename IndexType, typename Closure>
void device_classical_spmv(
    const size_type num_rows,
    acc::range<acc::reduced_row_major<1u, AccessType, const __half>> val,
    const IndexType* __restrict__ col_idxs,
    const IndexType* __restrict__ row_ptrs, acc::range<input_accessor> b,
    acc::range<output_accessor> c, Closure scale, sycl::nd_item<3> item_ct1)
{
    using arithmetic_type = typename output_accessor::arithmetic_type;

    // the group index in dimension 1 selects the column of b and c
    const auto column_id = item_ct1.get_group(1);

    auto subgroup_tile = group::tiled_partition<subgroup_size>(
        group::this_thread_block(item_ct1));
    // first entry (relative to the row start) handled by this work-item; the
    // factor 2 reflects the pair of entries consumed per iteration
    const auto lane_offset = subgroup_tile.thread_rank() * 2;
    const auto row_stride =
        thread::get_subwarp_num_flat<subgroup_size>(item_ct1);

    for (auto row = thread::get_subwarp_id_flat<subgroup_size>(item_ct1);
         row < num_rows; row += row_stride) {
        const auto row_end = row_ptrs[row + 1];
        auto partial = zero<arithmetic_type>();

        auto ind = row_ptrs[row] + lane_offset;
        while (ind < row_end) {
            partial += val(ind) * b(col_idxs[ind], column_id);
            // the second entry of the pair may fall past the end of the row
            const auto next = ind + 1;
            if (next < row_end) {
                partial += val(next) * b(col_idxs[next], column_id);
            }
            ind += subgroup_size * 2;
        }

        // sum the per-work-item partial results over the subgroup tile
        auto subgroup_result = ::gko::kernels::dpcpp::reduce(
            subgroup_tile, partial,
            [](const arithmetic_type& lhs, const arithmetic_type& rhs) {
                return lhs + rhs;
            });
        // TODO: check the barrier
        subgroup_tile.sync();

        // only the work-item with rank 0 (offset 0) writes the row result
        if (lane_offset == 0) {
            c(row, column_id) = scale(subgroup_result, c(row, column_id));
        }
    }
}


template <size_type subgroup_size, typename matrix_accessor,
typename input_accessor, typename output_accessor, typename IndexType>
void abstract_classical_spmv(const size_type num_rows,
Expand Down Expand Up @@ -1547,6 +1587,10 @@ void spmv(std::shared_ptr<const DpcppExecutor> exec,
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
if (std::is_same<MatrixValueType, gko::half>::value) {
// each thread processes two elements at a time
max_length_per_row /= 2;
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
Expand Down Expand Up @@ -1619,6 +1663,10 @@ void advanced_spmv(std::shared_ptr<const DpcppExecutor> exec,
max_length_per_row = a->get_num_stored_elements() /
std::max<size_type>(a->get_size()[0], 1);
}
if (std::is_same<MatrixValueType, gko::half>::value) {
// each thread processes two elements at a time
max_length_per_row /= 2;
}
max_length_per_row = std::max<size_type>(max_length_per_row, 1);
host_kernel::select_classical_spmv(
classical_kernels(),
Expand Down

0 comments on commit 81d7814

Please sign in to comment.