Skip to content

Commit

Permalink
[GPU] Fix crash on swiglu fused case (due to outer_ofm == 1) (#27972)
Browse files Browse the repository at this point in the history
### Details:
 - fixed crash happens in minicpm-1b-sft int4 model 

### Tickets:
 - *ticket-id*
  • Loading branch information
yeonbok authored Dec 9, 2024
1 parent b840082 commit a3f4edb
Showing 1 changed file with 10 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,14 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
return selector.Default(tune_params(1, 1, 4, 4, 1, 1, 1, EXE_MODE_DEFAULT));
}
} else if (is_weight_small_kn(params, output_f)) {
if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2)
return selector.Default(tune_params(1, 1, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
else
if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) {
if (swiglu_fused)
return selector.Default(tune_params(1, 1, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT));
else
return selector.Default(tune_params(1, 1, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
} else {
return selector.Default(tune_params(1, 2, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT));
}
} else {
if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) {
return selector.Default(tune_params(1, 1, 4, 4, 1, 1, 1, EXE_MODE_DEFAULT));
Expand Down Expand Up @@ -865,7 +869,9 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
auto output_f = get_output_aligned_bf_size(fc_params, false).second;

WeightsLayout weights_layout = WeightsLayout::os_iyx_osv16;
if (!is_swiglu_fused(fc_params) && fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
if (is_swiglu_fused(fc_params)) {
weights_layout = WeightsLayout::os_is_yx_osv32_isv2;
} else if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
&& (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == WeightsLayout::os_is_yx_osv64_isv2)
&& (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4)
&& is_weight_horizontal(fc_params, output_f)) {
Expand Down

0 comments on commit a3f4edb

Please sign in to comment.