diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index 0e9f3e14eb3f2e..7db9c2c0d59419 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -926,6 +926,7 @@ void prepare_buffer_fusing::run(program& p) { if (kv_out_layout.is_dynamic()) { // set dynamic pad dims for shape agnostic kernel + const auto& desc = node.get_primitive(); padding::DynamicDimsMask info_dynamic_pad; info_dynamic_pad[concat_axis] = 1; kv_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad; @@ -942,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) { auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) { auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx); - const size_t scales_zp_concat_axis = 2; + const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); padding::DynamicDimsMask info_dynamic_pad_scales; info_dynamic_pad_scales[scales_zp_concat_axis] = 1; scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales; @@ -958,7 +959,6 @@ void prepare_buffer_fusing::run(program& p) { update_dep(gather_prim, info_dynamic_pad, 0); } - const auto& desc = node.get_primitive(); if (desc->compressed) { update_scale_zp(2, 1); diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp index fef2a3c51ee821..1ffbfbbfbade37 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp @@ -230,7 +230,7 @@ struct kv_cache_impl : multi_stage_primitive { if (desc->get_compression_zp_inputs_num() > 0) { // Copy zero points to the new buffer if needed - execute_stage(events, instance, res_events, scale_concat_stage, zp_concat_stage); + execute_stage(events, instance, res_events, zp_concat_stage, zp_concat_stage); } // Perform dynamic quantization of new token data and append result to the KV-cache @@ -417,15 +417,19 @@ struct kv_cache_impl : multi_stage_primitive { return params; } - static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { + static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, + bool is_scale = true, + bool is_shape_agnostic = false) { auto params = get_default_params(impl_param, is_shape_agnostic); const auto concat_axis = 2; params.axis = convert_axis(concat_axis, impl_param.get_output_layout().get_rank()); - auto inputs_count = 1; - auto comp_scale_past_layout = impl_param.input_layouts[3]; - auto comp_scale_present_layout = impl_param.output_layouts[2]; + const auto inputs_count = 1; + const auto input_idx = is_scale ? 3 : 4; // scale or zp + const auto output_idx = is_scale ? 2 : 3; // scale or zp + auto comp_scale_past_layout = impl_param.input_layouts[input_idx]; + auto comp_scale_present_layout = impl_param.output_layouts[output_idx]; params.inputs.resize(inputs_count); params.inputs[0] = convert_data_tensor(comp_scale_past_layout); @@ -435,10 +439,10 @@ struct kv_cache_impl : multi_stage_primitive { const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; std::map in_tensor_to_offset_map = { - {0, in_offsets_map.at(3)}, // compression_scale_past + {0, in_offsets_map.at(input_idx)}, // compression_[scale/zp]_past }; std::map out_tensor_to_offset_map = { - {0, out_offsets_map.at(2)}, // compression_scale_present + {0, out_offsets_map.at(output_idx)}, // compression_[scale/zp]_present }; params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); @@ -451,8 +455,11 @@ struct kv_cache_impl : multi_stage_primitive { auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic()); auto& concat_kernel_selector = kernel_selector_t::Instance(); kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params)); - const bool indirect = impl_param.typed_desc()->indirect; - const bool compressed = impl_param.typed_desc()->compressed; + + const auto desc = impl_param.typed_desc(); + const bool indirect = desc->indirect; + const bool compressed = desc->compressed; + const bool has_zp_input = desc->get_compression_zp_inputs_num() > 0; if (indirect) { auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false); auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance(); @@ -464,9 +471,14 @@ struct kv_cache_impl : multi_stage_primitive { auto& dq_kernel_selector = dq_kernel_selector_t::Instance(); kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params)); - auto concat_scale_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic()); auto& concat_scale_zp_kernel_selector = kernel_selector_t::Instance(); - kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_zp_kernel_params)); + auto concat_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic()); + kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_kernel_params)); + + if (has_zp_input) { + auto concat_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic()); + kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_zp_kernel_params)); + } } return cldnn::make_unique(kernels_data); } @@ -494,9 +506,15 @@ struct kv_cache_impl : multi_stage_primitive { _kernels_data[concat_stage].kernels[1].skip_execution = true; // Update dynamic quantization parameters - auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic()); + auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic()); (_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]); _kernels_data[scale_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(3).count() == 0; + + if (impl_param.typed_desc()->get_compression_zp_inputs_num() > 0) { + auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic()); + (_kernels_data[zp_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[zp_concat_stage]); + _kernels_data[zp_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(4).count() == 0; + } } } }; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp index dad93d94946490..86b49484282238 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp @@ -287,7 +287,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitiveget_compression_zp_inputs_num() > 0; if (desc->is_kv_compressed) { data_inputs_num -= 2; // key and value compression scales are handled separately diff --git a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h index da0a9397433f89..e95e2e94ff4ab0 100644 --- a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h @@ -62,9 +62,8 @@ class typed_primitive_inst : public typed_primitive_inst_base= 0 ? sequence_axis : past_layout_rank + sequence_axis; } - static int64_t get_scale_zp_sequence_axis() { - // The order of scales and zero points is fixed, so use constant axis - const auto scale_zp_concat_axis = 2; + static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) { + const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis]; return scale_zp_concat_axis; } diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 3712202f4926c8..6e1af3f5429283 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -338,13 +338,13 @@ void primitive_inst::update_shape() { _impl_params->state_layouts.resize(compressed_cache_variable->has_zp_state() ? 3 : 2); auto scales_state = compressed_cache_variable->get_compression_scale_state(); - auto new_scales_layout = compressed_cache_variable->get_compression_scale_state()->get_layout(); + auto new_scales_layout = scales_state->get_layout(); update_state_layout(*scales_state, new_scales_layout, 1); if (compressed_cache_variable->has_zp_state()) { - auto scales_state = compressed_cache_variable->get_compression_zp_state(); - auto new_zp_layout = compressed_cache_variable->get_compression_zp_state()->get_layout(); - update_state_layout(*scales_state, new_zp_layout, 2); + auto zp_state = compressed_cache_variable->get_compression_zp_state(); + auto new_zp_layout = zp_state->get_layout(); + update_state_layout(*zp_state, new_zp_layout, 2); } } } @@ -851,7 +851,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { auto prealloc_shape = updated_layouts[i].get_shape(); const auto shape_rank = prealloc_shape.size(); const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank) - : kv_cache_inst::get_scale_zp_sequence_axis(); + : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); prealloc_shape[seq_axis] += tmp_prealloc_count; required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies()); @@ -883,7 +883,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { const auto& desc = _node->as().get_primitive(); const auto shape_rank = updated_layouts[i].get_shape().size(); const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank) - : kv_cache_inst::get_scale_zp_sequence_axis(); + : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis); } else { @@ -907,7 +907,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { auto& present_layout = _impl_params->output_layouts[i]; const auto present_layout_rank = present_layout.get_partial_shape().size(); const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank) - : kv_cache_inst::get_scale_zp_sequence_axis();; + : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); auto max_pad = kv_cache_inst::get_max_pad(present_layout, _max_output_layout_count[i], @@ -978,7 +978,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { if (max_pad > 0) { if (auto compressed_cache_variable = dynamic_cast(&variable)) { auto present_scales_layout = _impl_params->output_layouts[2]; - const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();; + const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); // In case of compressed KV-cache, calling update_impl for each iteration // because of scales layout [batch, num_heads, seq_len, head_size], which requires proper @@ -990,8 +990,9 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { compressed_cache_variable->get_compression_scale_state()->set_memory(_outputs[2], present_scales_layout); if (compressed_cache_variable->has_zp_state()) { auto present_zp_layout = present_scales_layout; + present_zp_layout.data_type = _impl_params->output_layouts[3].data_type; - _impl_params->output_layouts[3] = present_scales_layout; + _impl_params->output_layouts[3] = present_zp_layout; compressed_cache_variable->get_compression_zp_state()->set_memory(_outputs[3], present_zp_layout); } } @@ -1373,7 +1374,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() { if (desc->compressed) { auto compressed_cache_variable = dynamic_cast(&variable); auto& present_scales_layout = _impl_params->output_layouts[2]; - const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(); + const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis); GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl; @@ -1385,7 +1386,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() { GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_zp_layout's pad : " << present_scales_layout.to_string() << std::endl; - compressed_cache_variable->get_compression_zp_state()->set_layout(present_scales_layout); + compressed_cache_variable->get_compression_zp_state()->set_layout(present_zp_layout); } } @@ -1397,7 +1398,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() { if (desc->compressed) { auto& past_scale_layout = _impl_params->input_layouts[3]; - const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(); + const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes); kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis); if (desc->get_compression_zp_inputs_num() > 0) { @@ -2104,6 +2105,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool _outputs = allocate_outputs(); } } + if (_node) { + GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n"; + } _impls_factory = std::make_shared(_node); _impl_params->strm = _network.get_stream_ptr(); for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) { diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_kv_cache.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_kv_cache.cl index 16add2e0397d32..169e7cc62635b8 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_kv_cache.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_kv_cache.cl @@ -87,12 +87,14 @@ KERNEL(dynamic_quantize_gpu_kv_cache)( #if ASYMMETRIC_QUANTIZATION min_value = work_group_reduce_min(min_value); max_value = work_group_reduce_max(max_value); + // If the range of input data is zero, it is adjusted to the minimum value(0.001). - half diff_value = max_value == min_value ? (grp_max) : (max_value - min_value); + ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value); ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value); - ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) - CHAR_MAX; + ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN; OUTPUT1_TYPE scale = (OUTPUT1_TYPE)(scale_tmp); OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(zp_tmp); + #else max_value = work_group_reduce_max(max_value); OUTPUT1_TYPE scale = 127.0h / max_value; @@ -120,7 +122,13 @@ KERNEL(dynamic_quantize_gpu_kv_cache)( #if GROUP_SCALES_WITH_ZP output_scale[scale_idx + 1] = zp; #else + + #if OUTPUT2_IS_FP + output_zp[scale_idx] = zp; + #else output_zp[scale_idx] = convert_char_rte(zp); + #endif + #endif #else output_scale[scale_idx] = 1.0h / scale; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl index 72dc057d44a040..236fe4c9dab684 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl @@ -104,12 +104,12 @@ KERNEL(dynamic_quantize_gpu_ref)( #if ASYMMETRIC_QUANTIZATION // If the range of input data is zero, it is adjusted to the minimum value(0.001). - half diff_value = max_val == min_val ? (grp_max) : (max_val - min_val); + ACCUMULATOR_TYPE diff_value = max_val == min_val ? (grp_max) : (max_val - min_val); ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value); # if UNSIGNED_OUTPUT ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp); # else // !UNSIGNED_OUTPUT - ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp) - CHAR_MAX; + ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_val * scale_tmp) + CHAR_MIN; # endif OUTPUT1_TYPE scale = (OUTPUT1_TYPE)(scale_tmp); OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(zp_tmp); @@ -161,6 +161,12 @@ KERNEL(dynamic_quantize_gpu_ref)( #if ASYMMETRIC_QUANTIZATION && GROUP_SCALES_WITH_ZP output_scale[scale_idx + 1] = zp; #elif ASYMMETRIC_QUANTIZATION - output_zp[scale_idx] = convert_uchar_rte(zp); + #if OUTPUT2_IS_FP + output_zp[scale_idx] = zp; + #elif UNSIGNED_OUTPUT + output_zp[scale_idx] = convert_uchar_rte(zp); + #else + output_zp[scale_idx] = convert_char_rte(zp); + #endif #endif } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/sdpa_utils.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/sdpa_utils.cl index 5943f23251bb7a..36c9741f3f3c7a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/sdpa_utils.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/sdpa_utils.cl @@ -27,6 +27,8 @@ #define KEY_OFF(x0, x1, x2, x3) _4D_OFF(KEY, x0, x1, x2, x3) #define VAL_OFF(x0, x1, x2, x3) _4D_OFF(VAL, x0, x1, x2, x3) #define MSK_OFF(x0, x1, x2, x3) _4D_OFF(MSK, x0, x1, x2, x3) +#define KEY_COMP_OFF(x0, x1, x2, x3) _4D_OFF(KEY_COMP, x0, x1, x2, x3) +#define VAL_COMP_OFF(x0, x1, x2, x3) _4D_OFF(VAL_COMP, x0, x1, x2, x3) #define DST_OFF(x0, x1, d, h, w) \ (((x0) % DST_B0) * DST_SB0 + ((x0) / DST_B0) * DST_S0 \ diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl index 1584dffe95a3c3..b50410a3cf761f 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl @@ -18,6 +18,12 @@ #include "include/batch_headers/sdpa_utils.cl" #include "include/batch_headers/tile_ops.cl" +/* The quantization parameter may be unique for each token/element */ +#define QUANTIZE_2D 2 + +/* The quantization parameter shares the same value across the work-group */ +#define QUANTIZE_COMMON 3 + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define DIV_UP(x, y) (((x) + (y)-1) / (y)) @@ -133,7 +139,9 @@ DECLARE_2D_TILE_RSELECT(a_scale_tile_type, SUBGROUP_SIZE, ugemm_vs_sg_tile_n, 1, __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG - const global half *K, const global half *Q, const global half *V, + const global KEY_DATA_T *K, + const global QRY_DATA_T *Q, + const global VAL_DATA_T *V, global half *A, #if WITH_ATTN_MASK const global half *msk, @@ -141,10 +149,18 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG #if WITH_SCALE global SCALE_DATA_T *scale_ptr, #endif - int d, int k, int q) { + int d, int k, int q +#ifdef KV_COMPRESSED + , const global KEY_ATTR_SCALES_DATA_T *K_scales + , const global KEY_ATTR_ZP_DATA_T *K_zp + , const global VAL_ATTR_SCALES_DATA_T *V_scales + , const global VAL_ATTR_ZP_DATA_T *V_zp +#endif + ) { uint sg_ij = sub_group_broadcast(get_local_id(1), 0); uint b0 = get_group_id(1); uint b1 = get_group_id(2); + uint b0_kv = b0 / KV_GROUP_SIZE; uint wg_j0 = get_group_id(0) * ugemm_kq_wg_tile_n; @@ -154,6 +170,13 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG uint ldv = VAL_S2; uint lda = DST_S2; +#if KEY_SCALES || KEY_ZERO_POINTS + uint ldkq = DIV_UP(d, KEY_GROUP_SIZE); +#endif +#if VAL_SCALES || VAL_ZERO_POINTS + uint ldvq = DIV_UP(d, VAL_GROUP_SIZE); +#endif + /* Subgroup IDs for each GEMM */ uint sg_i_kq = sg_ij % ugemm_kq_sg_per_wg_m; uint sg_j_kq = sg_ij / ugemm_kq_sg_per_wg_m; @@ -183,11 +206,30 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG const bool need_sum_barrier = (ugemm_vs_barrier_count == 0); /* Locate K/Q/V/A matrices within batch */ - K += KEY_OFF(b1, (b0 / KV_GROUP_SIZE), 0, 0) + INPUT1_OFFSET; - Q += QRY_OFF(b1, b0, 0, 0) + INPUT0_OFFSET; - V += VAL_OFF(b1, (b0 / KV_GROUP_SIZE), 0, 0) + INPUT2_OFFSET; + K += (KEY_OFF(b1, b0_kv, 0, 0) + INPUT1_OFFSET) / KEY_ELEMENTS_PER_BYTE; + Q += (QRY_OFF(b1, b0, 0, 0) + INPUT0_OFFSET); + V += (VAL_OFF(b1, b0_kv, 0, 0) + INPUT2_OFFSET) / VAL_ELEMENTS_PER_BYTE; A += DST_OFF(b1, b0, 0, 0, 0); +#if KEY_SCALES + K_scales += KEY_COMP_OFF(b1, b0_kv, 0, 0); +#endif +#if KEY_SCALES == QUANTIZE_COMMON + float k_scale = convert_float(*K_scales); +#endif +#if KEY_ZERO_POINTS + K_zp += KEY_COMP_OFF(b1, b0_kv, 0, 0) / KEY_ZP_ELEMENTS_PER_BYTE; +#endif +#if VAL_SCALES + V_scales += VAL_COMP_OFF(b1, b0_kv, 0, 0); +#endif +#if VAL_SCALES == QUANTIZE_COMMON + float v_scale = convert_float(*V_scales); +#endif +#if VAL_ZERO_POINTS + V_zp += VAL_COMP_OFF(b1, b0_kv, 0, 0) / VAL_ZP_ELEMENTS_PER_BYTE; +#endif + __builtin_assume_aligned(K, K_ALIGN); __builtin_assume_aligned(Q, Q_ALIGN); __builtin_assume_aligned(V, V_ALIGN); @@ -283,7 +325,25 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG /* Calculate S = (K^T) * Q */ s_tile_type S_tile = ugemm_kq(K, ldk, Q_slm, D_MAX, k, ugemm_kq_wg_tile_n, d, k0, - 0, 0, sg_i_kq, sg_j_kq, (local char *)ugemm_slm); + 0, 0, sg_i_kq, sg_j_kq, (local char *)ugemm_slm +#if KEY_SCALES == QUANTIZE_2D + , + K_scales +#endif +#if KEY_ZERO_POINTS + , + K_zp +#endif +#if (KEY_SCALES == QUANTIZE_2D) || KEY_ZERO_POINTS + , + ldkq +#endif + ); + +#if KEY_SCALES == QUANTIZE_COMMON +#define k_scale_op(x) ((x)*k_scale) + tile_elementwise(S_tile, k_scale_op); +#endif /* Apply attention mask */ #if WITH_ATTN_MASK @@ -419,10 +479,31 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG /* Accumulate A += V * S */ int k_chunk = min(k - k0, ugemm_kq_wg_tile_m); - a_tile_type A_tile1 = ugemm_vs(V, ldv, S_slm, ugemm_kq_wg_tile_m, d, - ugemm_kq_wg_tile_n, k_chunk, 0, 0, 0, sg_i_vs, sg_j_vs, - (local char *)ugemm_slm); - V += ldv * ugemm_kq_wg_tile_m; + + a_tile_type A_tile1 = ugemm_vs( + V, ldv, S_slm, ugemm_kq_wg_tile_m, d, ugemm_kq_wg_tile_n, + k_chunk, 0, 0, 0, sg_i_vs, sg_j_vs, (local char *)ugemm_slm +#if VAL_SCALES == QUANTIZE_2D + , + V_scales +#endif +#if VAL_ZERO_POINTS + , + V_zp +#endif +#if (VAL_SCALES == QUANTIZE_2D) || VAL_ZERO_POINTS + , + ldvq +#endif + ); + + V += ldv * ugemm_kq_wg_tile_m / VAL_ELEMENTS_PER_BYTE; +#if VAL_SCALES == QUANTIZE_2D + V_scales += ldvq * ugemm_kq_wg_tile_m; +#endif +#if VAL_ZERO_POINTS == QUANTIZE_2D + V_zp += ldvq * ugemm_kq_wg_tile_m / VAL_ZP_ELEMENTS_PER_BYTE; +#endif tile_binary(A_tile, A_tile1, binary_add); } @@ -440,6 +521,11 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG tile_binary(A_scale_tile, A_scale_tile_load, binary_add); } +#if VAL_SCALES == QUANTIZE_COMMON +#define v_scale_op(x) ((x)*v_scale) + tile_elementwise(A_tile, v_scale_op); +#endif + /* Rescale by 1 / (column sums) */ tile_elementwise(A_scale_tile, native_vrecip); tile_hbroadcast_mul(&A_tile, A_scale_tile); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp index 467dd71da37944..028c95b77c9b06 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp @@ -47,6 +47,8 @@ micro::Type convert_type(Datatype t) { switch (t) { case Datatype::F32: return micro::Type::f32; case Datatype::F16: return micro::Type::f16; + case Datatype::INT8: return micro::Type::s8; + case Datatype::UINT8: return micro::Type::u8; default: break; } OPENVINO_THROW("Unsupported dt: ", toString(t)); @@ -83,16 +85,26 @@ sdpa_config_t xehpg_h32_s64 = {16, 16, 16, 8, 4, 4, 2, 8}; sdpa_config_t xehpg_h32_s32 = {8, 8, 8, 8, 4, 4, 4, 4}; sdpa_config_t xehpg_h32_2nd = {8, 32, 16, 8, 8, 1, 2, 4}; +sdpa_config_t xehpg_q_h32 = {32, 16, 16, 16, 2, 8, 2, 8}; +sdpa_config_t xehpg_q_h32_2nd = {32, 16, 8, 8, 8, 1, 4, 2}; + sdpa_config_t xehpg_h64 = {32, 16, 16, 16, 4, 8, 4, 8}; sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8}; sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8}; sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2}; +sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4}; +sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2}; + sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8}; sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4}; sdpa_config_t xehpg_h128_2nd = {8, 16, 16, 8, 16, 1, 8, 2}; sdpa_config_t xehpg_h128_s256_2nd = {8, 16, 32, 8, 8, 1, 4, 2}; +sdpa_config_t xehpg_q_h128 = {32, 16, 16, 16, 8, 4, 8, 4}; +sdpa_config_t xehpg_q_h128_2nd = {32, 16, 16, 8, 16, 1, 8, 2}; +sdpa_config_t xehpg_q_h128_s64_2nd = {16, 16, 16, 8, 16, 1, 8, 2}; + sdpa_config_t xehpg_h256 = {16, 16, 32, 8, 16, 2, 8, 4}; sdpa_config_t xehpg_h256_s128 = {8, 16, 32, 16, 8, 4, 8, 4}; sdpa_config_t xehpg_h256_s32 = {8, 16, 32, 8, 16, 2, 8, 4}; @@ -110,28 +122,52 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2}; sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2}; sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1}; +sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16}; + sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8}; sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2}; sdpa_config_t xehpc_h128_s32 = {16, 16, 16, 16, 8, 2, 8, 2}; sdpa_config_t xehpc_h128_2nd = {32, 32, 32, 16, 8, 1, 4, 2}; +sdpa_config_t xehpc_q_h128 = {16, 64, 16, 32, 16, 2, 8, 4}; +sdpa_config_t xehpc_q_h128_s64 = {16, 16, 32, 16, 4, 4, 4, 4}; +sdpa_config_t xehpc_q_h128_s32 = {16, 16, 32, 16, 4, 2, 4, 2}; +sdpa_config_t xehpc_q_h128_2nd = {32, 32, 16, 32, 4, 1, 4, 1}; +sdpa_config_t xehpc_q_h128_s32_2nd = {16, 32, 16, 16, 8, 1, 4, 2}; + sdpa_config_t xehpc_h256 = {16, 32, 32, 32, 8, 4, 8, 4}; sdpa_config_t xehpc_h256_s64 = {16, 32, 32, 32, 8, 1, 8, 1}; sdpa_config_t xehpc_h256_2nd = {16, 16, 16, 16, 16, 1, 16, 1}; -sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q) { +sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q, bool quantized) { if (head_size <= 32) { + if (quantized && seq >= 128) { + if (thin_q) return &xehpg_q_h32_2nd; + return &xehpg_q_h32; + } if (thin_q) return &xehpg_h32_2nd; if (seq <= 32) return &xehpg_h32_s32; if (seq <= 64) return &xehpg_h32_s64; if (seq <= 256) return &xehpg_h32_s256; return &xehpg_h32; } else if (head_size <= 64) { + if (quantized) { + if (thin_q) return &xehpg_q_h64_2nd; + return &xehpg_q_h64; + } if (thin_q) return &xehpg_h64_2nd; if (seq <= 64) return &xehpg_h64_s64; if (seq <= 128) return &xehpg_h64_s128; return &xehpg_h64; } else if (head_size <= 128) { + if (quantized) { + if (thin_q) { + if (seq <= 64) return &xehpg_q_h128_s64_2nd; + return &xehpg_q_h128_2nd; + } + if (seq <= 32) return &xehpg_h128_s32; + return &xehpg_q_h128; + } if (thin_q) { if (seq <= 256) return &xehpg_h128_s256_2nd; return &xehpg_h128_2nd; @@ -151,7 +187,7 @@ sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q) { return nullptr; } -sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) { +sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q, bool quantized) { if (head_size <= 32) { if (thin_q) return &xehpc_h32_2nd; if (seq <= 32) return &xehpc_h32_s32; @@ -161,10 +197,20 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) { if (seq <= 64) return &xehpc_h64_s64_2nd; return &xehpc_h64_2nd; } + if (quantized && seq >= 256) return &xehpc_q_h64; if (seq <= 32) return &xehpc_h64_s32; if (seq <= 64) return &xehpc_h64_s64; return &xehpc_h64; } else if (head_size <= 128) { + if (quantized) { + if (thin_q) { + if (seq <= 32) return &xehpc_q_h128_s32_2nd; + return &xehpc_q_h128_2nd; + } + if (seq <= 32) return &xehpc_q_h128_s32; + if (seq <= 64) return &xehpc_q_h128_s64; + return &xehpc_q_h128; + } if (thin_q) return &xehpc_h128_2nd; if (seq <= 32) return &xehpc_h128_s32; if (seq <= 64) return &xehpc_h128_s64; @@ -179,6 +225,11 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) { } // namespace +const bool kq_common_scales = false; +const bool kq_common_zp = false; +const bool vs_common_scales = false; +const bool vs_common_zp = false; + std::mutex SDPAKernelMicro::m; void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Package& gemm_kq, micro::Package& gemm_vs, bool is_prefill) const { @@ -200,15 +251,18 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag sdpa_config_t *config = nullptr; bool thin_q = (!n_queries.is_dynamic && (n_queries.v <= 16)) || !is_prefill; + bool is_quantized = (K.GetDType() == Datatype::UINT8 || K.GetDType() == Datatype::INT8) || + (V.GetDType() == Datatype::UINT8 || V.GetDType() == Datatype::INT8); + switch (params.engineInfo.arch) { case gpu_arch::xe_hpg: { - config = choose_config_xehpg(static_cast(head_size), static_cast(n_keys.v), thin_q); + config = choose_config_xehpg(static_cast(head_size), static_cast(n_keys.v), thin_q, is_quantized); break; } case gpu_arch::xe_hpc: case gpu_arch::xe2: case gpu_arch::xe3: { - config = choose_config_xehpc(static_cast(head_size), static_cast(n_keys.v), thin_q); + config = choose_config_xehpc(static_cast(head_size), static_cast(n_keys.v), thin_q, is_quantized); break; } default: break; @@ -224,13 +278,47 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag /* Set up GEMMProblem structure for first GEMM: K^T * Q */ micro::GEMMProblem problem; - problem.Ta = problem.Ta_ext = convert_type(K.GetDType()); - problem.Tb = problem.Tb_ext = convert_type(Q.GetDType()); + problem.Ta_ext = convert_type(K.GetDType()); + problem.Tb_ext = convert_type(Q.GetDType()); + + problem.Ta = problem.Tb = micro::Type::f16; problem.Tc = problem.Tc_ext = micro::Type::f32; problem.Ts = problem.Tc; auto problem_kq = problem; problem_kq.A.layout = micro::MatrixLayout::T; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_kq; + opts_kq.localB = true; + opts_kq.slmPtr = true; + + if (params.conf.is_kv_compressed && !kq_common_scales) { + const auto scale_dt = convert_type(params.key_cache_comp_scale.GetDType()); + problem_kq.Ta_scale = scale_dt; + problem_kq.A_scale.alignment = micro::data_type_size(scale_dt); + + problem_kq.A_scale.layout = micro::MatrixLayout::T; + problem_kq.aScale2D = true; + } + + if (params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization) { + const auto zp_dt = convert_type(params.key_cache_comp_zp.GetDType()); + problem_kq.Tao = zp_dt; + problem_kq.AO.alignment = micro::data_type_size(zp_dt); + problem_kq.AO.layout = micro::MatrixLayout::T; + problem_kq.aoPtrDims = kq_common_zp ? 0 : 2; + problem_kq.aOffset = micro::ABOffset::Calc; + } + + if (params.conf.is_kv_compressed) { + problem_kq.aqGroupM = 1; + problem_kq.aqGroupK = (kq_common_scales || kq_common_zp) ? 1 : params.conf.head_size; + } + + opts_kq.scaleA = params.conf.is_kv_compressed && !kq_common_scales; + opts_kq.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization; + problem_kq.B.layout = micro::MatrixLayout::Pr; problem_kq.C.layout = micro::MatrixLayout::T; problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta)); @@ -253,18 +341,49 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag reqs_kq.push_back(micro::StrategyRequirement::WGM == config->wg_m_kq); reqs_kq.push_back(micro::StrategyRequirement::WGN == config->wg_n_kq); - /* Set up microkernel options */ - micro::GEMMProtocol::Options opts_kq; - opts_kq.localB = true; - opts_kq.slmPtr = true; - /* Ask microkernel provider for microkernel */ - gemm_kq = micro::select_gemm_microkernel(opts_kq, hw_info, sizes, problem_kq, reqs_kq); + try { + gemm_kq = micro::select_gemm_microkernel(opts_kq, hw_info, sizes, problem_kq, reqs_kq); + } catch (const std::runtime_error &ex) { + GPU_DEBUG_TRACE_DETAIL << "Can't create KQ sdpa_micro kernel: " << ex.what() << "\n"; + throw; + } + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_vs; + opts_vs.localB = true; + opts_vs.slmPtr = true; /* Update for second GEMM: V*S */ auto problem_vs = problem; - problem_vs.Ta = problem_vs.Ta_ext = convert_type(V.GetDType()); + problem_vs.Ta_ext = convert_type(V.GetDType()); problem_vs.A.layout = micro::MatrixLayout::N; + + if (params.conf.is_kv_compressed && !vs_common_scales) { + auto scale_dt = convert_type(params.value_cache_comp_scale.GetDType()); + problem_vs.Ta_scale = scale_dt; + problem_vs.A_scale.alignment = micro::data_type_size(scale_dt); + problem_vs.A_scale.layout = micro::MatrixLayout::N; + problem_vs.aScale2D = true; + } + + if (params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization) { + auto zp_dt = convert_type(params.value_cache_comp_zp.GetDType()); + problem_vs.Tao = zp_dt; + problem_vs.AO.alignment = micro::data_type_size(zp_dt); + problem_vs.AO.layout = micro::MatrixLayout::N; + problem_vs.aoPtrDims = vs_common_zp ? 0 : 2; + problem_vs.aOffset = micro::ABOffset::Calc; + } + + if (params.conf.is_kv_compressed) { + problem_vs.aqGroupM = (vs_common_scales || vs_common_zp) ? 1 : micro::rnd_up_pow2(params.conf.head_size); + problem_vs.aqGroupK = 1; + } + + opts_vs.scaleA = params.conf.is_kv_compressed && !vs_common_scales; + opts_vs.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization; + problem_vs.B.layout = micro::MatrixLayout::Pr; problem_vs.C.layout = micro::MatrixLayout::N; problem_vs.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta)); @@ -281,20 +400,23 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag reqs_vs.push_back(micro::StrategyRequirement::WGM == config->wg_m_vs); reqs_vs.push_back(micro::StrategyRequirement::WGN == config->wg_n_vs); - micro::GEMMProtocol::Options opts_vs; - opts_vs.localB = true; - opts_vs.slmPtr = true; - auto adjust_vs = [](micro::GEMMStrategy &strategy) { /* Enable dpasw */ strategy.dpasw |= strategy.fused; }; /* Ask microkernel provider for microkernel */ - gemm_vs = micro::select_gemm_microkernel(opts_vs, hw_info, sizes, problem_vs, reqs_vs, adjust_vs); + try { + gemm_vs = micro::select_gemm_microkernel(opts_vs, hw_info, sizes, problem_vs, reqs_vs, adjust_vs); + } catch (const std::runtime_error &ex) { + GPU_DEBUG_TRACE_DETAIL << "Can't create VS sdpa_micro kernel: " << ex.what() << "\n"; + throw; + } } ParamsKey SDPAKernelMicro::GetSupportedKey() const { ParamsKey k; + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); k.EnableInputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F16); @@ -344,9 +466,6 @@ bool SDPAKernelMicro::Validate(const Params& p) const { if (params.conf.head_size > 256) return false; - if (params.conf.is_kv_compressed) - return false; - // Do not use sdpa_micro kernel with a scalar-value mask if (params.inputs.size() > 3 && !params.inputs[3].is_dynamic() && params.inputs[3].LogicalSize() == 1) return false; @@ -388,6 +507,52 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m jit.AddConstant(MakeJitConstant("TRANSPOSE_K", false)); + jit.AddConstant(MakeJitConstant("QRY_DATA_T", toCLType(Q.GetDType()))); + jit.AddConstant(MakeJitConstant("KEY_DATA_T", toCLType(K.GetDType()))); + jit.AddConstant(MakeJitConstant("VAL_DATA_T", toCLType(V.GetDType()))); + + if (params.conf.is_kv_compressed) { + jit.AddConstant(MakeJitConstant("KV_COMPRESSED", 1)); + jit.AddConstant(MakeJitConstant("KEY_ATTR_SCALES_DATA_T", toCLType(params.key_cache_comp_scale.GetDType()))); + jit.AddConstant(MakeJitConstant("VAL_ATTR_SCALES_DATA_T", toCLType(params.value_cache_comp_scale.GetDType()))); + + if (params.conf.use_asymmetric_quantization) { + jit.AddConstant(MakeJitConstant("KEY_ATTR_ZP_DATA_T", toCLType(params.key_cache_comp_zp.GetDType()))); + jit.AddConstant(MakeJitConstant("VAL_ATTR_ZP_DATA_T", toCLType(params.value_cache_comp_zp.GetDType()))); + } + } + + auto elems_per_byte = [](Datatype dt) { + switch (dt) { + case Datatype::UINT4: + case Datatype::INT4: + return 2; + default: + return 1; + } + }; + + jit.AddConstant(MakeJitConstant("KEY_ELEMENTS_PER_BYTE", elems_per_byte(params.inputs[1].GetDType()))); + jit.AddConstant(MakeJitConstant("VAL_ELEMENTS_PER_BYTE", elems_per_byte(params.inputs[2].GetDType()))); + + if (params.conf.is_kv_compressed) { + int kq_scale_mask = (static_cast(params.conf.is_kv_compressed) << 1) | static_cast(kq_common_scales); + int vs_scale_mask = (static_cast(params.conf.is_kv_compressed) << 1) | static_cast(vs_common_scales); + jit.AddConstant(MakeJitConstant("KEY_SCALES", kq_scale_mask)); + jit.AddConstant(MakeJitConstant("VAL_SCALES", vs_scale_mask)); + jit.AddConstant(MakeJitConstant("KEY_GROUP_SIZE", params.conf.head_size)); + jit.AddConstant(MakeJitConstant("VAL_GROUP_SIZE", params.conf.head_size)); + + if (params.conf.use_asymmetric_quantization) { + int kq_zp_mask = (static_cast(params.conf.use_asymmetric_quantization) << 1) | static_cast(kq_common_zp); + int vs_zp_mask = (static_cast(params.conf.use_asymmetric_quantization) << 1) | static_cast(vs_common_zp); + jit.AddConstant(MakeJitConstant("KEY_ZERO_POINTS", kq_zp_mask)); + jit.AddConstant(MakeJitConstant("VAL_ZERO_POINTS", vs_zp_mask)); + jit.AddConstant(MakeJitConstant("KEY_ZP_ELEMENTS_PER_BYTE", elems_per_byte(params.key_cache_comp_zp.GetDType()))); + jit.AddConstant(MakeJitConstant("VAL_ZP_ELEMENTS_PER_BYTE", elems_per_byte(params.value_cache_comp_zp.GetDType()))); + } + } + int tile_k = gemm_kq.getSetting("wg_tile_m"); int tile_q = gemm_kq.getSetting("wg_tile_n"); int tile_v = gemm_vs.getSetting("wg_tile_m"); @@ -470,6 +635,18 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m jit.Merge(unit_parameters("VAL")); jit.Merge(unit_parameters("DST")); + if (params.conf.is_kv_compressed) { + jit.AddConstant(MakeJitConstant("KEY_SCALE", params.key_cache_comp_scale)); + jit.AddConstant(MakeJitConstant("VAL_SCALE", params.value_cache_comp_scale)); + + const std::vector default_order = { 0, 1, 2, 3 }; + jit.Merge(convert_strides("KEY_COMP", "KEY_SCALE", default_order)); + jit.Merge(convert_strides("VAL_COMP", "VAL_SCALE", default_order)); + + jit.Merge(unit_parameters("KEY_COMP")); + jit.Merge(unit_parameters("VAL_COMP")); + } + return jit; } @@ -521,6 +698,17 @@ clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, bool is kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 1}); // K kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 2}); // Q + if (params.conf.is_kv_compressed) { + uint32_t input_idx = static_cast(params.inputs.size()); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx + 0}); // K scales + if (params.conf.use_asymmetric_quantization) + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx + 2}); // K zp + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx + 1}); // V scales + if (params.conf.use_asymmetric_quantization) + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx + 3}); // V zp + } + const auto& Q = params.inputs[0]; const auto& K = params.inputs[1]; diff --git a/src/plugins/intel_gpu/src/kernel_selector/micro_utils.hpp b/src/plugins/intel_gpu/src/kernel_selector/micro_utils.hpp index c6b0e031a027e8..2d28caec5694af 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/micro_utils.hpp +++ b/src/plugins/intel_gpu/src/kernel_selector/micro_utils.hpp @@ -20,12 +20,14 @@ #include "gpu/intel/microkernels/package.hpp" #include "gpu/intel/jit/gemm/include/microkernel_provider.hpp" #include "gpu/intel/microkernels/shim.hpp" +#include "common/utils.hpp" namespace micro { using Package = dnnl::impl::gpu::intel::micro::Package; using HWInformation = dnnl::impl::gpu::intel::jit::HWInformation; using GEMMProblem = dnnl::impl::gpu::intel::jit::GEMMProblem; +using ABOffset = dnnl::impl::gpu::intel::jit::ABOffset; using GEMMStrategy = dnnl::impl::gpu::intel::jit::GEMMStrategy; using GEMMProtocol = dnnl::impl::gpu::intel::micro::GEMMProtocol; using MatrixLayout = dnnl::impl::gpu::intel::jit::MatrixLayout; @@ -36,6 +38,8 @@ using ShimOptions = dnnl::impl::gpu::intel::micro::ShimOptions; using HostLanguage = dnnl::impl::gpu::intel::micro::HostLanguage; using Setting = dnnl::impl::gpu::intel::micro::Setting; +using dnnl::impl::utils::rnd_up_pow2; + // Wrapper for Package which is used in clKernelData with forward declaration // to avoid including this header in many places in plugin // which may cause symbols conflicts with oneDNN @@ -77,6 +81,10 @@ static inline int alignment_for_ld(int ld) { return dnnl::impl::gpu::intel::jit::alignmentForLD(ld); } +static inline uint8_t data_type_size(micro::Type dt) { + return uint8_t(dnnl::impl::types::data_type_size(micro::Type(dt).get_dnnl_type())); +} + } // namespace micro #undef UNUSED diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp index 6a50a55e619fc9..c63a0b27f38577 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp @@ -127,17 +127,18 @@ std::shared_ptr class KVCacheCompressionMatcher : public ov::pass::MatcherPass { public: OPENVINO_MATCHER_PASS_RTTI("KVCacheCompressionMatcher"); - KVCacheCompressionMatcher(ov::element::Type compression_dt); + KVCacheCompressionMatcher(ov::element::Type compression_dt, bool supports_immad); }; -KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compression_dt) { +KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compression_dt, bool supports_immad) { using namespace ov::pass::pattern; if (compression_dt != element::i8 && compression_dt != element::u8) return; const auto quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; - const auto output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; + const auto output_storage_type = supports_immad ? ov::op::internal::DynamicQuantize::OutputStorageType::Planar + : ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; bool combine_scales_and_zp = output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; GPU_DEBUG_LOG << "KV-cache compression configuration: " @@ -219,7 +220,7 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi config.output_storage_type = output_storage_type; if (config.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric) - config.zp_dt = query_node->get_output_element_type(0); + config.zp_dt = supports_immad ? element::i8 : query_node->get_output_element_type(0); key_past_rv_node = update_past_read_value(key_past_rv_node, config); value_past_rv_node = update_past_read_value(value_past_rv_node, config); @@ -284,8 +285,8 @@ bool KVCacheCompression::run_on_model(const std::shared_ptr& m) { return pass::GraphRewrite::run_on_model(m); } -KVCacheCompression::KVCacheCompression(ov::element::Type compression_dt) { - add_matcher(compression_dt); +KVCacheCompression::KVCacheCompression(ov::element::Type compression_dt, bool supports_immad) { + add_matcher(compression_dt, supports_immad); } } // namespace intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.hpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.hpp index 036fdb78914891..f4a930686520ba 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.hpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.hpp @@ -32,8 +32,9 @@ namespace intel_gpu { class KVCacheCompression : public ov::pass::GraphRewrite { public: + OPENVINO_GRAPH_REWRITE_RTTI("KVCacheCompression"); - KVCacheCompression(ov::element::Type compression_dt); + KVCacheCompression(ov::element::Type compression_dt, bool supports_immad); bool run_on_model(const std::shared_ptr& m) override; }; diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp index 6721d0f9ebd608..908732dc357222 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp @@ -191,7 +191,8 @@ std::vector shape_infer(const KVCacheCompressed* op, auto quantized_data_shapes = ov::op::internal::DynamicQuantize::shape_infer(&dq_op, { input_shapes[1] }); - const auto scales_concat_axis = 2; + const auto concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size()); + const auto scales_concat_axis = op->get_quantization_attrs().scales_zp_output_order[concat_axis]; ov::PartialShape compression_scale_shape = input_shapes[3]; compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis]; out_shapes[2] = compression_scale_shape; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 9ec31bcca46973..4702aa02aaf571 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -965,7 +965,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); auto kv_cache_compression_dt = config.get_property(ov::hint::kv_cache_precision); - manager.register_pass(kv_cache_compression_dt); + manager.register_pass(kv_cache_compression_dt, device_info.supports_immad); manager.register_pass(); diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp index 3b1376d19b4fea..3e1ddda5c33c34 100644 --- a/src/plugins/intel_gpu/src/runtime/execution_config.cpp +++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp @@ -235,7 +235,7 @@ void ExecutionConfig::update_specific_default_properties(const cldnn::device_inf return; specific_default_properties_is_set = true; - // Enable KV-cache compression by default for non-systolic platforms + // Enable KV-cache compression by default for non-systolic platforms MFDNN-11755 if (get_property(ov::hint::kv_cache_precision) == ov::element::undefined && !info.supports_immad) { set_property(ov::hint::kv_cache_precision(ov::element::i8)); } diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp index fe9b917bee9aef..3bf1a9e937c37c 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp @@ -71,10 +71,14 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { dq_config.scales_zp_output_order.emplace_back(3); dq_config.output_storage_type = storage_type; + bool has_zp_output = dq_config.quantization_type == QuantizationType::Asymmetric && + dq_config.output_storage_type == OutputStorageType::Planar; + auto reorder_1 = reorder("reorder_1", input_info("input"), layout{ input_ps, data_types::f16, format::bfyx }); auto dyn_quan_prim = dynamic_quantize("dyn_quan_prim", input_info("reorder_1"), dq_config); auto reorder_data = reorder("reorder_data", input_info("dyn_quan_prim", 0), layout{ input_ps, data_types::f16, format::bfyx }); auto reorder_scale = reorder("reorder_scale", input_info("dyn_quan_prim", 1), layout{ scales_ps, data_types::f16, format::bfyx }); + auto reorder_zp = reorder("reorder_zp", input_info("dyn_quan_prim", 2), layout{ scales_ps, data_types::f16, format::bfyx }); // Implemented dynamic quantize kernel auto get_ref_results = [&]() { @@ -86,6 +90,9 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { reorder_scale ); + if (has_zp_output) + topology.add(reorder_zp); + auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); @@ -98,19 +105,27 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { auto outputs = network.execute(); - auto output_layout = outputs.begin()->second.get_layout(); - auto output_mem = outputs.begin()->second.get_memory(); + std::vector output_buffers; + for (const auto& output : outputs) { + auto output_layout = output.second.get_layout(); + auto output_mem = output.second.get_memory(); + output_buffers.push_back(engine.reinterpret_buffer(*output_mem, output_layout)); + } - return engine.reinterpret_buffer(*output_mem, output_layout); + return output_buffers; }; topology topology( input_layout("input", in_layout_f32), reorder_1, dyn_quan_prim, - reorder_data + reorder_data, + reorder_scale ); + if (has_zp_output) + topology.add(reorder_zp); + auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); @@ -126,23 +141,28 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { auto outputs = network->execute(); - auto output_mem = outputs.begin()->second.get_memory(); - cldnn::mem_lock output_ptr (output_mem, get_test_stream()); - - auto ref_output_mem = get_ref_results(); - cldnn::mem_lock output_ptr_ref (ref_output_mem, get_test_stream()); - size_t count = 0; - float max_diff = 0.f; - float avg = 0.f; - for (size_t i = 0; i < output_ptr_ref.size(); ++i) { - auto abs_diff = std::abs(output_ptr_ref[i] - output_ptr[i]); - if (max_diff < abs_diff) - max_diff = abs_diff; - avg += abs_diff; - count++; - ASSERT_LE(abs_diff, 1); + std::vector output_buffers; + for (const auto& output : outputs) { + auto output_layout = output.second.get_layout(); + auto output_mem = output.second.get_memory(); + output_buffers.push_back(engine.reinterpret_buffer(*output_mem, output_layout)); + } + + auto ref_output_buffers = get_ref_results(); + + ASSERT_EQ(ref_output_buffers.size(), output_buffers.size()); + + std::cout << "Outputs number: " << ref_output_buffers.size() << "\n"; + + for (size_t i = 0; i < ref_output_buffers.size(); i++) { + cldnn::mem_lock output_ptr(output_buffers[i], get_test_stream()); + cldnn::mem_lock output_ptr_ref(ref_output_buffers[i], get_test_stream()); + + for (size_t i = 0; i < output_ptr_ref.size(); ++i) { + auto abs_diff = std::abs(output_ptr_ref[i] - output_ptr[i]); + ASSERT_LE(abs_diff, 1); + } } - GPU_DEBUG_LOG << "---> count: " << count << ", max_diff:" << max_diff << ", avg_diff: " << (avg/count) << std::endl; } }; @@ -215,21 +235,67 @@ TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_reorde data_types::i8, data_types::undefined, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); } -TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_asym) { +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_asym_planar) { + this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_asym_planar) { + this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered_asym_planar) { + this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_reordered_asym_planar) { + this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_asym_interleaved) { this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Asymmetric, UINT64_MAX, data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } -TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_asym) { +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_asym_interleaved) { this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Asymmetric, UINT64_MAX, data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } -TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered_asym) { +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered_asym_interleaved) { this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Asymmetric, UINT64_MAX, data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_reordered_asym_interleaved) { + this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_asym_planar_i8_zp) { + this->test_dynamic_quantization(false, {-1, 8, -1, 32}, {1, 8, 1, 32}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::i8, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_asym_planar_i8_zp) { + this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::i8, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered_asym_planar_i8_zp) { + this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::i8, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_reordered_asym_planar_i8_zp) { + this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::i8, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); +} + TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_inner_most_dim_zero_values_asym) { this->test_dynamic_quantization(false, {-1, 8, -1, 128}, {1, 8, 52, 128}, QuantizationType::Asymmetric, UINT64_MAX, data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache", true); diff --git a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp index 67123f1d84cfe7..ca101d1b5d2c6c 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp @@ -93,7 +93,7 @@ TEST_F(TransformationTestsF, KVCacheCompression) { ov::ResultVector results{ result }; model = std::make_shared(results, params); - manager.register_pass(ov::element::i8); + manager.register_pass(ov::element::i8, false); } { ov::op::internal::DynamicQuantize::Attributes dq_config; @@ -244,7 +244,7 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { ov::ResultVector results{ result }; model = std::make_shared(results, params); - manager.register_pass(ov::element::i8); + manager.register_pass(ov::element::i8, false); } { ov::op::internal::DynamicQuantize::Attributes dq_config;