[GPU] KV-cache compression micro_sdpa kernel #28004

Open · wants to merge 15 commits

@@ -926,6 +926,7 @@ void prepare_buffer_fusing::run(program& p) {
 
     if (kv_out_layout.is_dynamic()) {
         // set dynamic pad dims for shape agnostic kernel
+        const auto& desc = node.get_primitive();
         padding::DynamicDimsMask info_dynamic_pad;
         info_dynamic_pad[concat_axis] = 1;
         kv_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad;
@@ -942,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) {
         auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) {
             auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx);
 
-            const size_t scales_zp_concat_axis = 2;
+            const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
             padding::DynamicDimsMask info_dynamic_pad_scales;
             info_dynamic_pad_scales[scales_zp_concat_axis] = 1;
             scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales;
@@ -958,7 +959,6 @@ void prepare_buffer_fusing::run(program& p) {
             update_dep(gather_prim, info_dynamic_pad, 0);
         }
 
-        const auto& desc = node.get_primitive();
         if (desc->compressed) {
             update_scale_zp(2, 1);
 
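The hunk above moves the primitive descriptor lookup ahead of the padding setup so the scale/zp concat axis can be derived from the descriptor instead of a hard-coded constant. A minimal C++ sketch of the dynamic-pad idea, with a simplified stand-in for padding::DynamicDimsMask (the real cldnn type differs):

#include <array>
#include <cstdint>

// Simplified stand-in for cldnn's padding::DynamicDimsMask: one flag per dimension.
using DynamicDimsMask = std::array<bool, 8>;

// Only the axis along which the cache (or its scales/zp) grows is marked dynamic,
// so the shape-agnostic kernel tolerates padding changes on that axis alone.
DynamicDimsMask make_dynamic_pad_mask(int64_t axis) {
    DynamicDimsMask mask{};  // all dimensions static by default
    mask[axis] = true;       // the concat axis grows between generation steps
    return mask;
}
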
42 changes: 30 additions & 12 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -230,7 +230,7 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
 
         if (desc->get_compression_zp_inputs_num() > 0) {
             // Copy zero points to the new buffer if needed
-            execute_stage(events, instance, res_events, scale_concat_stage, zp_concat_stage);
+            execute_stage(events, instance, res_events, zp_concat_stage, zp_concat_stage);
         }
 
         // Perform dynamic quantization of new token data and append result to the KV-cache
@@ -417,15 +417,19 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         return params;
     }
 
-    static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
+    static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param,
+                                                                      bool is_scale = true,
+                                                                      bool is_shape_agnostic = false) {
         auto params = get_default_params<kernel_selector::concatenation_params>(impl_param, is_shape_agnostic);
 
         const auto concat_axis = 2;
         params.axis = convert_axis(concat_axis, impl_param.get_output_layout().get_rank());
 
-        auto inputs_count = 1;
-        auto comp_scale_past_layout = impl_param.input_layouts[3];
-        auto comp_scale_present_layout = impl_param.output_layouts[2];
+        const auto inputs_count = 1;
+        const auto input_idx = is_scale ? 3 : 4; // scale or zp
+        const auto output_idx = is_scale ? 2 : 3; // scale or zp
+        auto comp_scale_past_layout = impl_param.input_layouts[input_idx];
+        auto comp_scale_present_layout = impl_param.output_layouts[output_idx];
 
         params.inputs.resize(inputs_count);
         params.inputs[0] = convert_data_tensor(comp_scale_past_layout);
@@ -435,10 +439,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
 
         std::map<size_t, size_t> in_tensor_to_offset_map = {
-            {0, in_offsets_map.at(3)}, // compression_scale_past
+            {0, in_offsets_map.at(input_idx)}, // compression_[scale/zp]_past
         };
         std::map<size_t, size_t> out_tensor_to_offset_map = {
-            {0, out_offsets_map.at(2)}, // compression_scale_present
+            {0, out_offsets_map.at(output_idx)}, // compression_[scale/zp]_present
         };
 
         params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
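The helper now serves both tensor pairs: one flag selects the scale pair (input 3 → output 2) or the zero-point pair (input 4 → output 3). A small C++ restatement of that mapping (enum names are illustrative; input ports 0-1 are assumptions, while the scale/zp indices come from the code above):

#include <cstddef>
#include <utility>

enum KvCacheInput  : size_t { kPast = 0, kNewToken = 1, kScalePast = 3, kZpPast = 4 };
enum KvCacheOutput : size_t { kPresent = 0, kScalePresent = 2, kZpPresent = 3 };

std::pair<size_t, size_t> compression_ports(bool is_scale) {
    const size_t input_idx  = is_scale ? kScalePast    : kZpPast;      // 3 : 4
    const size_t output_idx = is_scale ? kScalePresent : kZpPresent;   // 2 : 3
    return {input_idx, output_idx};
}
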
@@ -451,8 +455,11 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic());
         auto& concat_kernel_selector = kernel_selector_t::Instance();
         kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
-        const bool indirect = impl_param.typed_desc<kv_cache>()->indirect;
-        const bool compressed = impl_param.typed_desc<kv_cache>()->compressed;
+
+        const auto desc = impl_param.typed_desc<kv_cache>();
+        const bool indirect = desc->indirect;
+        const bool compressed = desc->compressed;
+        const bool has_zp_input = desc->get_compression_zp_inputs_num() > 0;
         if (indirect) {
             auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false);
             auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance();
@@ -464,9 +471,14 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
             auto& dq_kernel_selector = dq_kernel_selector_t::Instance();
             kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params));
 
-            auto concat_scale_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
             auto& concat_scale_zp_kernel_selector = kernel_selector_t::Instance();
-            kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_zp_kernel_params));
+            auto concat_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
+            kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_kernel_params));
+
+            if (has_zp_input) {
+                auto concat_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+                kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_zp_kernel_params));
+            }
         }
         return cldnn::make_unique<kv_cache_impl>(kernels_data);
     }
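With the zero-point concat appended conditionally, kernels_data can now hold up to five entries. A hedged sketch of the push order produced by the code above when every feature is enabled (the real implementation addresses these through named stage constants such as concat_stage and zp_concat_stage, whose actual values may differ):

// Push order in get_kernels_data with indirect + compressed + zp inputs all enabled:
//   [0] concat            - appends the new K/V tokens to the cache
//   [1] beam-table update - only if desc->indirect
//   [2] dynamic quantize  - only if desc->compressed
//   [3] scale concat      - only if desc->compressed
//   [4] zp concat         - only if desc->compressed and zp inputs are present
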
@@ -494,9 +506,15 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
             _kernels_data[concat_stage].kernels[1].skip_execution = true;
 
             // Update dynamic quantization parameters
-            auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
+            auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
             (_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]);
             _kernels_data[scale_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(3).count() == 0;
+
+            if (impl_param.typed_desc<kv_cache>()->get_compression_zp_inputs_num() > 0) {
+                auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+                (_kernels_data[zp_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[zp_concat_stage]);
+                _kernels_data[zp_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(4).count() == 0;
+            }
         }
     }
 };
@@ -287,7 +287,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
         if (has_indirect_inputs(impl_param))
             data_inputs_num--;
 
-        auto has_zp_input_buffers = false;
+        auto has_zp_input_buffers = desc->get_compression_zp_inputs_num() > 0;
         if (desc->is_kv_compressed) {
             data_inputs_num -= 2; // key and value compression scales are handled separately
 
5 changes: 2 additions & 3 deletions src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
@@ -62,9 +62,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache> {
         return sequence_axis >= 0 ? sequence_axis : past_layout_rank + sequence_axis;
     }
 
-    static int64_t get_scale_zp_sequence_axis() {
-        // The order of scales and zero points is fixed, so use constant axis
-        const auto scale_zp_concat_axis = 2;
+    static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) {
+        const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis];
         return scale_zp_concat_axis;
     }
 
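A worked example of the new lookup, with a hypothetical permutation (the real order comes from the model's quantization attributes):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // Hypothetical scale/zp dim order relative to the data tensor's axes.
    const std::vector<int64_t> scales_zp_output_order = {0, 1, 3, 2};
    const int64_t concat_axis = 2;  // sequence axis of the KV data tensor

    // Same lookup as get_scale_zp_sequence_axis in the header above:
    // the scales now concatenate along axis 3, where the old code hard-coded 2.
    assert(scales_zp_output_order[concat_axis] == 3);
    return 0;
}
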
28 changes: 16 additions & 12 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -338,13 +338,13 @@ void primitive_inst::update_shape() {
             _impl_params->state_layouts.resize(compressed_cache_variable->has_zp_state() ? 3 : 2);
 
             auto scales_state = compressed_cache_variable->get_compression_scale_state();
-            auto new_scales_layout = compressed_cache_variable->get_compression_scale_state()->get_layout();
+            auto new_scales_layout = scales_state->get_layout();
             update_state_layout(*scales_state, new_scales_layout, 1);
 
             if (compressed_cache_variable->has_zp_state()) {
-                auto scales_state = compressed_cache_variable->get_compression_zp_state();
-                auto new_zp_layout = compressed_cache_variable->get_compression_zp_state()->get_layout();
-                update_state_layout(*scales_state, new_zp_layout, 2);
+                auto zp_state = compressed_cache_variable->get_compression_zp_state();
+                auto new_zp_layout = zp_state->get_layout();
+                update_state_layout(*zp_state, new_zp_layout, 2);
             }
         }
     }
@@ -851,7 +851,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                 auto prealloc_shape = updated_layouts[i].get_shape();
                 const auto shape_rank = prealloc_shape.size();
                 const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                             : kv_cache_inst::get_scale_zp_sequence_axis();
+                                             : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
 
                 prealloc_shape[seq_axis] += tmp_prealloc_count;
                 required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
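For intuition, a toy version of the preallocation arithmetic above (shapes invented for illustration):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    // Hypothetical KV-cache shape [batch, heads, seq_len, head_size], seq_axis = 2.
    std::vector<size_t> prealloc_shape = {1, 8, 32, 64};
    const size_t seq_axis = 2;
    const size_t tmp_prealloc_count = 128;  // extra tokens reserved ahead of time

    prealloc_shape[seq_axis] += tmp_prealloc_count;  // 32 -> 160
    const size_t required_buffer_size = std::accumulate(
        prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
    return required_buffer_size == 1 * 8 * 160 * 64 ? 0 : 1;  // 81920 elements
}
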
@@ -883,7 +883,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                 const auto& desc = _node->as<kv_cache>().get_primitive();
                 const auto shape_rank = updated_layouts[i].get_shape().size();
                 const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                             : kv_cache_inst::get_scale_zp_sequence_axis();
+                                             : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
 
                 prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
             } else {
@@ -907,7 +907,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
             auto& present_layout = _impl_params->output_layouts[i];
             const auto present_layout_rank = present_layout.get_partial_shape().size();
             const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank)
-                                              : kv_cache_inst::get_scale_zp_sequence_axis();;
+                                              : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
 
             auto max_pad = kv_cache_inst::get_max_pad(present_layout,
                                                       _max_output_layout_count[i],
@@ -978,7 +978,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
         if (max_pad > 0) {
             if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
                 auto present_scales_layout = _impl_params->output_layouts[2];
-                const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();;
+                const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
 
                 // In case of compressed KV-cache, calling update_impl for each iteration
                 // because of scales layout [batch, num_heads, seq_len, head_size], which requires proper
@@ -990,8 +990,9 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                 compressed_cache_variable->get_compression_scale_state()->set_memory(_outputs[2], present_scales_layout);
                 if (compressed_cache_variable->has_zp_state()) {
                     auto present_zp_layout = present_scales_layout;
+                    present_zp_layout.data_type = _impl_params->output_layouts[3].data_type;
 
-                    _impl_params->output_layouts[3] = present_scales_layout;
+                    _impl_params->output_layouts[3] = present_zp_layout;
                     compressed_cache_variable->get_compression_zp_state()->set_memory(_outputs[3], present_zp_layout);
                 }
             }
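The added data_type line matters because the zero-point tensor shares the scale tensor's shape and padding but not necessarily its element type (for example fp16 scales with int8 zero points). A minimal sketch of the fix, with a hypothetical layout struct standing in for cldnn::layout:

#include <cstdint>
#include <vector>

struct Layout {
    std::vector<int64_t> shape;  // shape incl. padding, simplified
    enum class DataType { f16, i8, u8 } data_type;
};

Layout make_zp_layout(const Layout& present_scales_layout, Layout::DataType zp_type) {
    Layout zp = present_scales_layout;  // copy the scale tensor's shape/padding
    zp.data_type = zp_type;             // but keep the zero-point's own element type
    return zp;
}
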
@@ -1373,7 +1374,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
     if (desc->compressed) {
         auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable);
         auto& present_scales_layout = _impl_params->output_layouts[2];
-        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
+        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
         kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis);
         GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
                                << " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl;
@@ -1385,7 +1386,7 @@
         GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
                                << " Updated present_zp_layout's pad : " << present_scales_layout.to_string() << std::endl;
 
-        compressed_cache_variable->get_compression_zp_state()->set_layout(present_scales_layout);
+        compressed_cache_variable->get_compression_zp_state()->set_layout(present_zp_layout);
     }
 }
 
@@ -1397,7 +1398,7 @@
 
     if (desc->compressed) {
         auto& past_scale_layout = _impl_params->input_layouts[3];
-        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
+        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
         kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis);
 
         if (desc->get_compression_zp_inputs_num() > 0) {
@@ -2104,6 +2105,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
             _outputs = allocate_outputs();
         }
     }
+    if (_node) {
+        GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
+    }
     _impls_factory = std::make_shared<ImplementationsFactory>(_node);
     _impl_params->strm = _network.get_stream_ptr();
     for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) {
@@ -87,12 +87,14 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
 #if ASYMMETRIC_QUANTIZATION
     min_value = work_group_reduce_min(min_value);
     max_value = work_group_reduce_max(max_value);
+
     // If the range of input data is zero, it is adjusted to the minimum value(0.001).
-    half diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
+    ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
     ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
-    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) - CHAR_MAX;
+    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
     OUTPUT1_TYPE scale = (OUTPUT1_TYPE)(scale_tmp);
     OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(zp_tmp);
+
 #else
     max_value = work_group_reduce_max(max_value);
     OUTPUT1_TYPE scale = 127.0h / max_value;
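The sign fix matters at the range edges: the corrected formula maps an input range [min, max] exactly onto [CHAR_MIN, CHAR_MAX], while the old "- CHAR_MAX" offset was off by one and pushed the top of the range past CHAR_MAX. A self-contained check, with values invented for illustration:

#include <climits>
#include <cmath>
#include <cstdio>

int main() {
    // Example range; in the kernel these come from a work-group min/max reduction.
    const float min_value = -1.0f, max_value = 3.0f;
    const float diff_value = max_value - min_value;          // 4.0
    const float scale = (CHAR_MAX - CHAR_MIN) / diff_value;  // 63.75
    const float zp = -min_value * scale + CHAR_MIN;          // -64.25 (fixed formula)

    std::printf("%ld\n", std::lrintf(min_value * scale + zp));  // -128 == CHAR_MIN
    std::printf("%ld\n", std::lrintf(max_value * scale + zp));  //  127 == CHAR_MAX
    // With the old "- CHAR_MAX" offset, zp would be -63.25 and the maximum would
    // quantize to 128, one step past CHAR_MAX.
    return 0;
}
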
@@ -120,7 +122,13 @@
 #if GROUP_SCALES_WITH_ZP
     output_scale[scale_idx + 1] = zp;
 #else
+
+#if OUTPUT2_IS_FP
+    output_zp[scale_idx] = zp;
+#else
+    output_zp[scale_idx] = convert_char_rte(zp);
+#endif
 
 #endif
 #else
     output_scale[scale_idx] = 1.0h / scale;
@@ -104,12 +104,12 @@ KERNEL(dynamic_quantize_gpu_ref)(
 
 #if ASYMMETRIC_QUANTIZATION
     // If the range of input data is zero, it is adjusted to the minimum value(0.001).
-    half diff_value = max_val == min_val ? (grp_max) : (max_val - min_val);
+    ACCUMULATOR_TYPE diff_value = max_val == min_val ? (grp_max) : (max_val - min_val);
     ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
 # if UNSIGNED_OUTPUT
     ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp);
 # else // !UNSIGNED_OUTPUT
-    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp) - CHAR_MAX;
+    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp) + CHAR_MIN;
 # endif
     OUTPUT1_TYPE scale = (OUTPUT1_TYPE)(scale_tmp);
     OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(zp_tmp);
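The reference kernel keeps both output signednesses; only the zero-point anchor differs. A plain C++ restatement of the two branches (types simplified):

#include <climits>

// UNSIGNED_OUTPUT anchors min at 0 (uchar range); the signed path anchors it at CHAR_MIN.
float zero_point(float min_val, float scale, bool unsigned_output) {
    return unsigned_output ? -min_val * scale
                           : -min_val * scale + CHAR_MIN;
}
// With min_val = -1.0 and scale = 63.75 as in the earlier example, the unsigned
// variant gives zp = 63.75: the minimum quantizes to 0 and the maximum to 255.
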
@@ -161,6 +161,12 @@
 #if ASYMMETRIC_QUANTIZATION && GROUP_SCALES_WITH_ZP
     output_scale[scale_idx + 1] = zp;
 #elif ASYMMETRIC_QUANTIZATION
-    output_zp[scale_idx] = convert_uchar_rte(zp);
+#if OUTPUT2_IS_FP
+    output_zp[scale_idx] = zp;
+#elif UNSIGNED_OUTPUT
+    output_zp[scale_idx] = convert_uchar_rte(zp);
+#else
+    output_zp[scale_idx] = convert_char_rte(zp);
+#endif
 #endif
 }
@@ -27,6 +27,8 @@
 #define KEY_OFF(x0, x1, x2, x3) _4D_OFF(KEY, x0, x1, x2, x3)
 #define VAL_OFF(x0, x1, x2, x3) _4D_OFF(VAL, x0, x1, x2, x3)
 #define MSK_OFF(x0, x1, x2, x3) _4D_OFF(MSK, x0, x1, x2, x3)
+#define KEY_COMP_OFF(x0, x1, x2, x3) _4D_OFF(KEY_COMP, x0, x1, x2, x3)
+#define VAL_COMP_OFF(x0, x1, x2, x3) _4D_OFF(VAL_COMP, x0, x1, x2, x3)
 
 #define DST_OFF(x0, x1, d, h, w) \
     (((x0) % DST_B0) * DST_SB0 + ((x0) / DST_B0) * DST_S0 \
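The new *_COMP_OFF macros reuse the existing _4D_OFF pattern to address the per-group compression scale and zero-point tensors of the micro_sdpa kernel. A hedged C++ sketch of what such a strided 4D offset computes (the real macro derives blocked strides from per-tensor defines, which this simplification ignores):

#include <cstddef>

// Index (x0, x1, x2, x3) into a 4D tensor with compile-time strides s0..s3.
constexpr size_t offset_4d(size_t x0, size_t x1, size_t x2, size_t x3,
                           size_t s0, size_t s1, size_t s2, size_t s3) {
    return x0 * s0 + x1 * s1 + x2 * s2 + x3 * s3;
}

// E.g. a dense [2, 8, 32, 1] scale tensor has strides {256, 32, 1, 1}:
static_assert(offset_4d(1, 2, 5, 0, 256, 32, 1, 1) == 325, "row-major 4D offset");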