Commit 5df2d53

WIP: [GPU] KV-cache compression micro_sdpa kernel

sshlyapn committed Dec 10, 2024
1 parent 67f2537 commit 5df2d53
Showing 17 changed files with 434 additions and 92 deletions.
@@ -926,6 +926,7 @@ void prepare_buffer_fusing::run(program& p) {

if (kv_out_layout.is_dynamic()) {
// set dynamic pad dims for shape agnostic kernel
+ const auto& desc = node.get_primitive();
padding::DynamicDimsMask info_dynamic_pad;
info_dynamic_pad[concat_axis] = 1;
kv_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad;
@@ -942,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) {
auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) {
auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx);

- const size_t scales_zp_concat_axis = 2;
+ const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
padding::DynamicDimsMask info_dynamic_pad_scales;
info_dynamic_pad_scales[scales_zp_concat_axis] = 1;
scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales;
@@ -958,7 +959,6 @@ void prepare_buffer_fusing::run(program& p) {
update_dep(gather_prim, info_dynamic_pad, 0);
}

- const auto& desc = node.get_primitive();
if (desc->compressed) {
update_scale_zp(2, 1);

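Note: the hunks above move `desc` up so the scales/zp concat axis can be derived from the primitive's quantization attributes instead of the hardcoded `2`. A minimal sketch of the dynamic-pad-mask pattern used here, with a hypothetical bitset standing in for `padding::DynamicDimsMask`:

```cpp
#include <bitset>
#include <cstdint>

// Hypothetical stand-in for padding::DynamicDimsMask (assumed to act as a
// per-dimension "this dim's padding is dynamic" bitmask).
using DynamicDimsMask = std::bitset<8>;

// Only the concat axis is marked dynamic, so the shape-agnostic kernel
// reads that dimension's padding at runtime and leaves the rest static.
DynamicDimsMask make_dynamic_pad_mask(int64_t concat_axis) {
    DynamicDimsMask mask;      // all dimensions static by default
    mask[concat_axis] = true;  // pad along the concat axis varies per step
    return mask;
}
```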
src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp (30 additions, 12 deletions)
@@ -230,7 +230,7 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {

if (desc->get_compression_zp_inputs_num() > 0) {
// Copy zero points to the new buffer if needed
- execute_stage(events, instance, res_events, scale_concat_stage, zp_concat_stage);
+ execute_stage(events, instance, res_events, zp_concat_stage, zp_concat_stage);
}

// Perform dynamic quantization of new token data and append result to the KV-cache
@@ -417,15 +417,19 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
return params;
}

- static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
+ static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param,
+                                                                   bool is_scale = true,
+                                                                   bool is_shape_agnostic = false) {
auto params = get_default_params<kernel_selector::concatenation_params>(impl_param, is_shape_agnostic);

const auto concat_axis = 2;
params.axis = convert_axis(concat_axis, impl_param.get_output_layout().get_rank());

- auto inputs_count = 1;
- auto comp_scale_past_layout = impl_param.input_layouts[3];
- auto comp_scale_present_layout = impl_param.output_layouts[2];
+ const auto inputs_count = 1;
+ const auto input_idx = is_scale ? 3 : 4; // scale or zp
+ const auto output_idx = is_scale ? 2 : 3; // scale or zp
+ auto comp_scale_past_layout = impl_param.input_layouts[input_idx];
+ auto comp_scale_present_layout = impl_param.output_layouts[output_idx];

params.inputs.resize(inputs_count);
params.inputs[0] = convert_data_tensor(comp_scale_past_layout);
@@ -435,10 +439,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;

std::map<size_t, size_t> in_tensor_to_offset_map = {
- {0, in_offsets_map.at(3)}, // compression_scale_past
+ {0, in_offsets_map.at(input_idx)}, // compression_[scale/zp]_past
};
std::map<size_t, size_t> out_tensor_to_offset_map = {
- {0, out_offsets_map.at(2)}, // compression_scale_present
+ {0, out_offsets_map.at(output_idx)}, // compression_[scale/zp]_present
};

params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
@@ -451,8 +455,11 @@
auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic());
auto& concat_kernel_selector = kernel_selector_t::Instance();
kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
- const bool indirect = impl_param.typed_desc<kv_cache>()->indirect;
- const bool compressed = impl_param.typed_desc<kv_cache>()->compressed;

+ const auto desc = impl_param.typed_desc<kv_cache>();
+ const bool indirect = desc->indirect;
+ const bool compressed = desc->compressed;
+ const bool has_zp_input = desc->get_compression_zp_inputs_num() > 0;
if (indirect) {
auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false);
auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance();
@@ -464,9 +471,14 @@
auto& dq_kernel_selector = dq_kernel_selector_t::Instance();
kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params));

- auto concat_scale_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
auto& concat_scale_zp_kernel_selector = kernel_selector_t::Instance();
- kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_zp_kernel_params));
+ auto concat_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
+ kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_kernel_params));
+
+ if (has_zp_input) {
+     auto concat_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+     kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_zp_kernel_params));
+ }
}
return cldnn::make_unique<kv_cache_impl>(kernels_data);
}
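Note: with this change the compressed path registers the scale-concat kernel unconditionally but the zp-concat kernel only when zero-point inputs exist. A sketch of the resulting kernel count (names are illustrative; only the conditions are taken from the hunk above):

```cpp
#include <cstddef>

// How many kernels get pushed into kernels_data for one kv_cache primitive,
// following the conditional push_back sequence in the hunk above.
size_t expected_kernel_count(bool indirect, bool compressed, bool has_zp_input) {
    size_t count = 1;             // main concat kernel
    if (indirect)
        count += 1;               // beam-table update kernel
    if (compressed) {
        count += 2;               // dynamic quantization + scale concat
        if (has_zp_input)
            count += 1;           // zero-point concat
    }
    return count;
}
```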
@@ -494,9 +506,15 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
_kernels_data[concat_stage].kernels[1].skip_execution = true;

// Update dynamic quantization parameters
- auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
+ auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
(_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]);
_kernels_data[scale_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(3).count() == 0;

+ if (impl_param.typed_desc<kv_cache>()->get_compression_zp_inputs_num() > 0) {
+     auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+     (_kernels_data[zp_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[zp_concat_stage]);
+     _kernels_data[zp_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(4).count() == 0;
+ }
}
}
};
@@ -287,7 +287,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
if (has_indirect_inputs(impl_param))
data_inputs_num--;

- auto has_zp_input_buffers = false;
+ auto has_zp_input_buffers = desc->get_compression_zp_inputs_num() > 0;
if (desc->is_kv_compressed) {
data_inputs_num -= 2; // key and value compression scales are handled separately

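Note: `has_zp_input_buffers` is now taken from the descriptor instead of being hardwired to false. A sketch of the input-count bookkeeping implied by this hunk (the zp subtraction is an assumption based on the flag's name; only the indirect and scale adjustments are visible above):

```cpp
#include <cstddef>

// Derive the number of plain data inputs an SDPA primitive sees once the
// auxiliary inputs are peeled off (sketch; mirrors the hunk above).
size_t data_inputs_num(size_t total_inputs,
                       bool has_indirect_inputs,
                       bool is_kv_compressed,
                       size_t compression_zp_inputs_num) {
    size_t n = total_inputs;
    if (has_indirect_inputs)
        n -= 1;                          // beam table is not a data input
    if (is_kv_compressed) {
        n -= 2;                          // key/value compression scales
        n -= compression_zp_inputs_num;  // optional key/value zero points
    }
    return n;
}
```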
src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h (2 additions, 3 deletions)
@@ -62,9 +62,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache> {
return sequence_axis >= 0 ? sequence_axis : past_layout_rank + sequence_axis;
}

- static int64_t get_scale_zp_sequence_axis() {
-     // The order of scales and zero points is fixed, so use constant axis
-     const auto scale_zp_concat_axis = 2;
+ static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) {
+     const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis];
return scale_zp_concat_axis;
}

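Note: the fixed axis 2 becomes a lookup through `scales_zp_output_order`. A worked example of the lookup; the interpretation of the order array is inferred from this call site, not from its definition:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: translate the KV-cache sequence axis into the scales/zp tensor's
// dimension order (assumed semantics of scales_zp_output_order).
int64_t get_scale_zp_sequence_axis(int64_t sequence_axis,
                                   const std::vector<int64_t>& scales_zp_output_order) {
    return scales_zp_output_order[sequence_axis];
}

int main() {
    // Identity order reproduces the old constant: axis 2 for a
    // [batch, num_heads, seq_len, group] scales layout.
    assert(get_scale_zp_sequence_axis(2, {0, 1, 2, 3}) == 2);
    // With a transposed scales/zp layout the sequence dim can land elsewhere.
    assert(get_scale_zp_sequence_axis(2, {0, 1, 3, 2}) == 3);
    return 0;
}
```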
src/plugins/intel_gpu/src/graph/primitive_inst.cpp (16 additions, 12 deletions)
@@ -338,13 +338,13 @@ void primitive_inst::update_shape() {
_impl_params->state_layouts.resize(compressed_cache_variable->has_zp_state() ? 3 : 2);

auto scales_state = compressed_cache_variable->get_compression_scale_state();
- auto new_scales_layout = compressed_cache_variable->get_compression_scale_state()->get_layout();
+ auto new_scales_layout = scales_state->get_layout();
update_state_layout(*scales_state, new_scales_layout, 1);

if (compressed_cache_variable->has_zp_state()) {
- auto scales_state = compressed_cache_variable->get_compression_zp_state();
- auto new_zp_layout = compressed_cache_variable->get_compression_zp_state()->get_layout();
- update_state_layout(*scales_state, new_zp_layout, 2);
+ auto zp_state = compressed_cache_variable->get_compression_zp_state();
+ auto new_zp_layout = zp_state->get_layout();
+ update_state_layout(*zp_state, new_zp_layout, 2);
}
}
}
@@ -843,7 +843,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
auto prealloc_shape = updated_layouts[i].get_shape();
const auto shape_rank = prealloc_shape.size();
const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
- : kv_cache_inst::get_scale_zp_sequence_axis();
+ : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);

prealloc_shape[seq_axis] += tmp_prealloc_count;
required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
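Note: both realloc call sites now thread the concat axis and quantization attributes through to the axis helper; the preallocation arithmetic itself is unchanged. Condensed for reference:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Grow the sequence dimension by the speculative preallocation count and
// return the flat element count the buffer must hold (mirrors the
// std::accumulate in the hunk above).
size_t required_buffer_elements(std::vector<size_t> shape,
                                int64_t seq_axis,
                                size_t prealloc_count) {
    shape[seq_axis] += prealloc_count;
    return std::accumulate(shape.begin(), shape.end(), size_t(1),
                           std::multiplies<size_t>());
}
```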
@@ -875,7 +875,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
const auto& desc = _node->as<kv_cache>().get_primitive();
const auto shape_rank = updated_layouts[i].get_shape().size();
const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
- : kv_cache_inst::get_scale_zp_sequence_axis();
+ : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);

prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
} else {
@@ -899,7 +899,7 @@
auto& present_layout = _impl_params->output_layouts[i];
const auto present_layout_rank = present_layout.get_partial_shape().size();
const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank)
- : kv_cache_inst::get_scale_zp_sequence_axis();;
+ : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);

auto max_pad = kv_cache_inst::get_max_pad(present_layout,
_max_output_layout_count[i],
@@ -970,7 +970,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
if (max_pad > 0) {
if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
auto present_scales_layout = _impl_params->output_layouts[2];
- const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();;
+ const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);

// In case of compressed KV-cache, calling update_impl for each iteration
// because of scales layout [batch, num_heads, seq_len, head_size], which requires proper
@@ -982,8 +982,9 @@
compressed_cache_variable->get_compression_scale_state()->set_memory(_outputs[2], present_scales_layout);
if (compressed_cache_variable->has_zp_state()) {
auto present_zp_layout = present_scales_layout;
+ present_zp_layout.data_type = _impl_params->output_layouts[3].data_type;

- _impl_params->output_layouts[3] = present_scales_layout;
+ _impl_params->output_layouts[3] = present_zp_layout;
compressed_cache_variable->get_compression_zp_state()->set_memory(_outputs[3], present_zp_layout);
}
}
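Note: this hunk fixes a real bug: the zp output layout previously inherited the scales layout wholesale, including its data type. A sketch of the corrected flow with a hypothetical minimal layout type:

```cpp
#include <cstdint>

enum class data_types : uint8_t { f16, i8, u8 };

// Hypothetical, heavily trimmed stand-in for cldnn::layout.
struct layout {
    int64_t shape[4];
    data_types data_type;
};

// Zero points share the scales' shape and padding but must keep their own
// element type (e.g. i8/u8 zp stored next to f16 scales).
layout make_zp_layout(const layout& scales_layout, data_types zp_dtype) {
    layout zp = scales_layout;  // copy shape/padding from the scales layout
    zp.data_type = zp_dtype;    // ...without inheriting the scales' dtype
    return zp;
}
```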
@@ -1361,7 +1362,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
if (desc->compressed) {
auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable);
auto& present_scales_layout = _impl_params->output_layouts[2];
- const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
+ const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis);
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
<< " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl;
@@ -1373,7 +1374,7 @@
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
<< " Updated present_zp_layout's pad : " << present_scales_layout.to_string() << std::endl;

- compressed_cache_variable->get_compression_zp_state()->set_layout(present_scales_layout);
+ compressed_cache_variable->get_compression_zp_state()->set_layout(present_zp_layout);
}
}
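Note: the scales pad update now uses the derived axis, and the zp state finally receives its own layout (`present_zp_layout`) instead of the scales one. A sketch of the pad bookkeeping; `update_pad`'s behavior here is an assumption based on its arguments:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical minimal layout: upper_pad[axis] counts reserved-but-unused
// elements past the valid data on that axis.
struct layout {
    std::vector<int64_t> upper_pad;
};

// After appending new_seq_len tokens in place, the remaining headroom on
// the sequence axis shrinks to max_pad - new_seq_len (assumed semantics).
void update_pad(layout& l, int64_t remaining_pad, int64_t seq_axis) {
    l.upper_pad[seq_axis] = remaining_pad;
}
```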

@@ -1385,7 +1386,7 @@

if (desc->compressed) {
auto& past_scale_layout = _impl_params->input_layouts[3];
- const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
+ const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis);

if (desc->get_compression_zp_inputs_num() > 0) {
@@ -2092,6 +2093,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
_outputs = allocate_outputs();
}
}
+ if (_node) {
+     GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
+ }
_impls_factory = std::make_shared<ImplementationsFactory>(_node);
_impl_params->strm = _network.get_stream_ptr();
for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) {
@@ -84,7 +84,7 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
min_value = work_group_reduce_min(min_value);
max_value = work_group_reduce_max(max_value);
ACCUMULATOR_TYPE scale = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / (max_value - min_value));
- ACCUMULATOR_TYPE zp = (ACCUMULATOR_TYPE)(-min_value * scale) - CHAR_MAX;
+ ACCUMULATOR_TYPE zp = (ACCUMULATOR_TYPE)(-min_value * scale) + CHAR_MIN;
#else
max_value = work_group_reduce_max(max_value);
ACCUMULATOR_TYPE scale = 127.0h / max_value;
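Note: the zero-point fix above (`- CHAR_MAX` replaced by `+ CHAR_MIN`) is worth a worked check. For asymmetric int8 quantization with scale = 255 / (max - min), the zero point must map min to -128; the old form was off by one and pushed max to 128:

```cpp
#include <cassert>
#include <climits>
#include <cmath>

int main() {
    const float min_value = -3.0f, max_value = 5.0f;
    const float scale = (CHAR_MAX - CHAR_MIN) / (max_value - min_value); // 255 / 8
    const float zp = -min_value * scale + CHAR_MIN;  // fixed formula

    // q(x) = round(x * scale + zp) should span exactly [-128, 127].
    assert(std::lround(min_value * scale + zp) == CHAR_MIN); // -128
    assert(std::lround(max_value * scale + zp) == CHAR_MAX); //  127

    // Old formula: zp' = -min * scale - CHAR_MAX maps max to 128, which
    // overflows signed char, and min to only -127.
    const float zp_old = -min_value * scale - CHAR_MAX;
    assert(std::lround(max_value * scale + zp_old) == 128);
    assert(std::lround(min_value * scale + zp_old) == -127);
    return 0;
}
```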
@@ -112,7 +112,11 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
#if GROUP_SCALES_WITH_ZP
output_scale[scale_idx + 1] = zp;
#else
+ #if OUTPUT2_IS_FP
output_zp[scale_idx] = zp;
+ #else
+ output_zp[scale_idx] = convert_char_rte(zp);
+ #endif
#endif
#else
output_scale[scale_idx] = 1.0h / scale;
Expand Down
@@ -103,7 +103,7 @@ KERNEL(dynamic_quantize_gpu_ref)(
# if UNSIGNED_OUTPUT
OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale);
# else // !UNSIGNED_OUTPUT
- OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale) - CHAR_MAX;
+ OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale) + CHAR_MIN;
# endif
#else // !ASYMMETRIC_QUANTIZATION
max_val = work_group_reduce_max(max_val);
@@ -153,6 +153,12 @@ KERNEL(dynamic_quantize_gpu_ref)(
#if ASYMMETRIC_QUANTIZATION && GROUP_SCALES_WITH_ZP
output_scale[scale_idx + 1] = zp;
#elif ASYMMETRIC_QUANTIZATION
- output_zp[scale_idx] = convert_uchar_rte(zp);
+ #if OUTPUT2_IS_FP
+ output_zp[scale_idx] = zp;
+ #elif UNSIGNED_OUTPUT
+ output_zp[scale_idx] = convert_uchar_rte(zp);
+ #else
+ output_zp[scale_idx] = convert_char_rte(zp);
+ #endif
#endif
#endif
}
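Note: the ref kernel now distinguishes three zp storage types instead of always rounding to uchar. A host-side sketch of the same dispatch (clamping added here for safety; the kernel itself relies on zp already being in range, and convert_*_rte is OpenCL's round-to-nearest-even):

```cpp
#include <algorithm>
#include <cmath>

enum class zp_storage { fp, u8, i8 };  // OUTPUT2_IS_FP / UNSIGNED_OUTPUT / default

// Mirror of the preprocessor branch above: keep zp as floating point, or
// round it into the packed integer range of the zp buffer.
float store_zp(float zp, zp_storage kind) {
    switch (kind) {
    case zp_storage::fp: return zp;  // zp buffer is floating point
    case zp_storage::u8: return std::clamp(std::nearbyint(zp), 0.0f, 255.0f);
    case zp_storage::i8: return std::clamp(std::nearbyint(zp), -128.0f, 127.0f);
    }
    return zp;
}
```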
@@ -27,6 +27,8 @@
#define KEY_OFF(x0, x1, x2, x3) _4D_OFF(KEY, x0, x1, x2, x3)
#define VAL_OFF(x0, x1, x2, x3) _4D_OFF(VAL, x0, x1, x2, x3)
#define MSK_OFF(x0, x1, x2, x3) _4D_OFF(MSK, x0, x1, x2, x3)
+ #define KEY_COMP_OFF(x0, x1, x2, x3) _4D_OFF(KEY_COMP, x0, x1, x2, x3)
+ #define VAL_COMP_OFF(x0, x1, x2, x3) _4D_OFF(VAL_COMP, x0, x1, x2, x3)

#define DST_OFF(x0, x1, d, h, w) \
(((x0) % DST_B0) * DST_SB0 + ((x0) / DST_B0) * DST_S0 \
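Note: the new `KEY_COMP_OFF`/`VAL_COMP_OFF` macros reuse the generic `_4D_OFF` scheme for the compression scale/zp tensors. A plain-C++ sketch of the blocked offset pattern that the visible `DST_OFF` fragment suggests (block sizes and strides are illustrative; `_4D_OFF`'s exact definition is not shown in this diff):

```cpp
#include <cstdint>

// One dimension of a blocked layout, following the visible pattern
//   ((x) % B) * SB + ((x) / B) * S
// where B is the block size, SB the stride within a block and S the
// stride between blocks.
struct dim_desc { int64_t block, stride_in_block, stride; };

int64_t blocked_off(int64_t x, const dim_desc& d) {
    return (x % d.block) * d.stride_in_block + (x / d.block) * d.stride;
}

// A 4D offset is the sum of per-dimension contributions.
int64_t off_4d(int64_t x0, int64_t x1, int64_t x2, int64_t x3,
               const dim_desc (&dims)[4]) {
    return blocked_off(x0, dims[0]) + blocked_off(x1, dims[1]) +
           blocked_off(x2, dims[2]) + blocked_off(x3, dims[3]);
}
```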
(Diff for the remaining changed files was not loaded.)