From 81d78148632a7bd5490e625f175276c9a9c007d5 Mon Sep 17 00:00:00 2001 From: "Yuhsiang M. Tsai" Date: Sun, 1 Oct 2023 20:29:06 +0200 Subject: [PATCH] process two elements for half csr spmv --- .../cuda_hip/matrix/csr_kernels.template.cpp | 51 ++++++++++++++++++- dpcpp/matrix/csr_kernels.dp.cpp | 50 +++++++++++++++++- 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/common/cuda_hip/matrix/csr_kernels.template.cpp b/common/cuda_hip/matrix/csr_kernels.template.cpp index bd2423d4306..e3f984971c8 100644 --- a/common/cuda_hip/matrix/csr_kernels.template.cpp +++ b/common/cuda_hip/matrix/csr_kernels.template.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -533,6 +533,47 @@ __device__ void device_classical_spmv(const size_type num_rows, } +template +__device__ void device_classical_spmv( + const size_type num_rows, + acc::range> val, + const IndexType* __restrict__ col_idxs, + const IndexType* __restrict__ row_ptrs, acc::range b, + acc::range c, Closure scale) +{ + using arithmetic_type = typename output_accessor::arithmetic_type; + auto subwarp_tile = + group::tiled_partition(group::this_thread_block()); + const auto subrow = thread::get_subwarp_num_flat(); + const auto subid = subwarp_tile.thread_rank() * 2; + // cannot use auto for hip because the type is + // __HIP_Coordinates<__HIP_BlockIdx>::__Y which is not allowed in accessor + // operator() + const int column_id = blockIdx.y; + auto row = thread::get_subwarp_id_flat(); + for (; row < num_rows; row += subrow) { + const auto ind_end = row_ptrs[row + 1]; + auto temp_val = zero(); + for (auto ind = row_ptrs[row] + subid; ind < ind_end; + ind += subwarp_size * 2) { + temp_val += val(ind) * b(col_idxs[ind], column_id); + if (ind + 1 < ind_end) { + temp_val += val(ind + 1) * b(col_idxs[ind + 1], column_id); + } + } + auto subwarp_result = + reduce(subwarp_tile, temp_val, 
[](const arithmetic_type& a, const arithmetic_type& b) { + return a + b; + }); + if (subid == 0) { + c(row, column_id) = scale(subwarp_result, c(row, column_id)); + } + } +} + + template __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv( @@ -2305,6 +2346,10 @@ void spmv(std::shared_ptr exec, max_length_per_row = a->get_num_stored_elements() / std::max(a->get_size()[0], 1); } + if (std::is_same::value) { + // each thread processes two elements, so halve the per-row length + max_length_per_row /= 2; + } max_length_per_row = std::max(max_length_per_row, 1); host_kernel::select_classical_spmv( classical_kernels(), @@ -2367,6 +2412,10 @@ void advanced_spmv(std::shared_ptr exec, max_length_per_row = a->get_num_stored_elements() / std::max(a->get_size()[0], 1); } + if (std::is_same::value) { + // each thread processes two elements, so halve the per-row length + max_length_per_row /= 2; + } max_length_per_row = std::max(max_length_per_row, 1); host_kernel::select_classical_spmv( classical_kernels(), diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index f970d62679b..6208cc93c5b 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -642,6 +642,46 @@ void device_classical_spmv(const size_type num_rows, } +template +void device_classical_spmv( + const size_type num_rows, + acc::range> val, + const IndexType* __restrict__ col_idxs, + const IndexType* __restrict__ row_ptrs, acc::range b, + acc::range c, Closure scale, sycl::nd_item<3> item_ct1) +{ + using arithmetic_type = typename output_accessor::arithmetic_type; + auto subgroup_tile = group::tiled_partition( + group::this_thread_block(item_ct1)); + const auto subrow = thread::get_subwarp_num_flat(item_ct1); + const auto subid = subgroup_tile.thread_rank() * 2; + const auto column_id = item_ct1.get_group(1); + auto 
row = thread::get_subwarp_id_flat(item_ct1); + for (; row < num_rows; row += subrow) { + const auto ind_end = row_ptrs[row + 1]; + auto temp_val = zero(); + for (auto ind = row_ptrs[row] + subid; ind < ind_end; + ind += subgroup_size * 2) { + temp_val += val(ind) * b(col_idxs[ind], column_id); + if (ind + 1 < ind_end) { + temp_val += val(ind + 1) * b(col_idxs[ind + 1], column_id); + } + } + auto subgroup_result = ::gko::kernels::dpcpp::reduce( + subgroup_tile, temp_val, + [](const arithmetic_type& a, const arithmetic_type& b) { + return a + b; + }); + // TODO: check the barrier + subgroup_tile.sync(); + if (subid == 0) { + c(row, column_id) = scale(subgroup_result, c(row, column_id)); + } + } +} + + template void abstract_classical_spmv(const size_type num_rows, @@ -1547,6 +1587,10 @@ void spmv(std::shared_ptr exec, max_length_per_row = a->get_num_stored_elements() / std::max(a->get_size()[0], 1); } + if (std::is_same::value) { + // each thread processes two elements, so halve the per-row length + max_length_per_row /= 2; + } max_length_per_row = std::max(max_length_per_row, 1); host_kernel::select_classical_spmv( classical_kernels(), @@ -1619,6 +1663,10 @@ void advanced_spmv(std::shared_ptr exec, max_length_per_row = a->get_num_stored_elements() / std::max(a->get_size()[0], 1); } + if (std::is_same::value) { + // each thread processes two elements, so halve the per-row length + max_length_per_row /= 2; + } max_length_per_row = std::max(max_length_per_row, 1); host_kernel::select_classical_spmv( classical_kernels(),