Skip to content

Commit

Permalink
[GPU] Fix gws of Resample onnx kernel (#27990)
Browse files Browse the repository at this point in the history
### Details:
 - *Fix gws value of resample onnx kernel with fs_b_yx_fsv32 format*

### Tickets:
 - *158837*
  • Loading branch information
kelvinchoi-intel authored Dec 10, 2024
1 parent 1be5963 commit 2d78f2a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ DeviceFeaturesKey ResampleKernelOnnx::get_required_device_features_key(const Par
return get_common_subgroups_device_features_key(params);
}

static size_t get_vec_size(const resample_params &params) {
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
return 2;
} else {
return 1;
}
}

ResampleKernelBase::DispatchData ResampleKernelOnnx::SetDefault(const kernel_selector::resample_params& arg) const {
DispatchData dispatchData;
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws;
Expand All @@ -96,7 +104,7 @@ ResampleKernelBase::DispatchData ResampleKernelOnnx::SetDefault(const kernel_sel
}

dispatchData.gws[0] = CeilDiv(out.X().v, opt_x_block_size) * out.Y().v * out.Z().v;
dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
dispatchData.gws[1] = Align(CeilDiv(out.Feature().v, get_vec_size(arg)), sub_group_size);
dispatchData.gws[2] = arg.outputs[0].Batch().v;

dispatchData.lws[0] = 1;
Expand Down Expand Up @@ -151,14 +159,9 @@ JitConstants ResampleKernelOnnx::GetJitConstants(const resample_params& params)
jit.AddConstant(MakeJitConstant("X_BLOCKS", CeilDiv(params.outputs[0].X().v, opt_x_block_size)));
jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));

size_t vec_size = 0;
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
vec_size = 2;
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 32));
} else {
vec_size = 1;
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16));
}
size_t vec_size = get_vec_size(params);
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16 * vec_size));

if (IsThreeSpatialResample(params))
jit.AddConstant(MakeJitConstant("THREE_SPATIAL_RESAMPLE", ""));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2371,6 +2371,7 @@ INSTANTIATE_TEST_SUITE_P(resample_opt_smoke_linear_onnx_4d_simple,
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, {}, {}},
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::bs_fs_yx_bsv16_fsv16, format::bs_fs_yx_bsv16_fsv16, {}, {}},
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32, {}, {}},
{ data_types::f16, {2, 32, 14, 14}, {2, 32, 28, 28}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::fs_b_yx_fsv32, format::fs_b_yx_fsv32, {}, {}},
}
));

Expand Down

0 comments on commit 2d78f2a

Please sign in to comment.