Skip to content

Commit

Permalink
WIP: [GPU] Debug accuracy issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Dec 20, 2024
1 parent 5df2d53 commit ead8d8a
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/debug_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ void dump(memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump
file_stream << "shape: " << size.to_string() << " ";
file_stream << "(count: " << size.count()
<< ", addr: " << mem->buffer_ptr()
<< ", original dt: " << mem->get_layout().data_type
<< ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")"
<< (dump_raw ? " raw data" : "") << std::endl;
} else {
file_stream << "shape: " << tmp_size.to_string() << " ";
file_stream << "(count: " << tmp_size.count()
<< ", addr: " << mem->buffer_ptr()
<< ", original dt: " << mem->get_layout().data_type
<< ", original format: " << cldnn::fmt_to_str(mem->get_layout().format)
<< ", original shape: " << size.to_string() << ")"
<< (dump_raw ? " raw data" : "") << std::endl;
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
auto comp_scale_past_layout = impl_param.input_layouts[input_idx];
auto comp_scale_present_layout = impl_param.output_layouts[output_idx];

GPU_DEBUG_TRACE_DETAIL << "Update params, input: " << comp_scale_past_layout.to_short_string() << ", output: " << comp_scale_past_layout << "\n";

params.inputs.resize(inputs_count);
params.inputs[0] = convert_data_tensor(comp_scale_past_layout);
params.outputs[0] = convert_data_tensor(comp_scale_present_layout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
}

if (desc->is_kv_compressed) {
GPU_DEBUG_TRACE_DETAIL << "Update scale layout: " << impl_param.get_input_layout(data_inputs_num) << "\n";
params.key_cache_comp_scale = convert_data_tensor(impl_param.get_input_layout(data_inputs_num));
GPU_DEBUG_TRACE_DETAIL << "Updated scale tensor pad: " << params.key_cache_comp_scale.Y().pad.before << " "
<< params.key_cache_comp_scale.Y().pad.after << "\n";

params.value_cache_comp_scale = convert_data_tensor(impl_param.get_input_layout(data_inputs_num + 1));

if (has_zp_input_buffers) {
Expand Down
16 changes: 15 additions & 1 deletion src/plugins/intel_gpu/src/graph/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,21 @@ int32_t kv_cache_inst::get_prealloc_iter_num() {
// iteration.
// - Therfore, to avoid this situation where the allocation and copying occurs simutaneously for all the kv_cache_insts,
// we assigned different prealloc-size for each kv cache so that we could prevent a memory peak
return 128 + kv_cache_id % 64;
int KV_STRIDE = 0;
if (const auto env_var = std::getenv("KV_STRIDE")) {
std::istringstream ss(env_var);
ss >> KV_STRIDE;
static bool print_once = true;
if (print_once) {
std::cout << ">>> KV_STRIDE = " << KV_STRIDE << "\n";
print_once = false;
}
}
if (KV_STRIDE != 0) {
return KV_STRIDE;
}

return 128 + kv_cache_id % 64;;
}

void kv_cache_inst::update_shape_info_tensor(const kernel_impl_params& params) {
Expand Down
90 changes: 86 additions & 4 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
#endif
int d, int k, int q
#ifdef KV_COMPRESSED
, const global KEY_ATTR_SCALES_DATA_T *K_scales
, const global KEY_ATTR_ZP_DATA_T *K_zp
, const global VAL_ATTR_SCALES_DATA_T *V_scales
, const global VAL_ATTR_ZP_DATA_T *V_zp
, global KEY_ATTR_SCALES_DATA_T *K_scales
, global KEY_ATTR_ZP_DATA_T *K_zp
, global VAL_ATTR_SCALES_DATA_T *V_scales
, global VAL_ATTR_ZP_DATA_T *V_zp
#endif
) {
uint sg_ij = sub_group_broadcast(get_local_id(1), 0);
Expand All @@ -164,6 +164,87 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG

uint wg_j0 = get_group_id(0) * ugemm_kq_wg_tile_n;





// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
// for (int j = 0; j < 32; j++) {
// #ifdef KV_COMPRESSED
// for (int i = k; i < k + KEY_SCALE_PAD_AFTER_SIZE_Y; i++) {
// K_scales[(j * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// K_zp[(j * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// }

// for (int i = k; i < k + VAL_SCALE_PAD_AFTER_SIZE_Y; i++) {
// V_scales[(j * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// V_zp[(j * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// }
// #endif
// }

// printf("h=0 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 key zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);


// printf("h=1 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[k + KEY_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[k + KEY_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[k + VAL_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 key zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[k + VAL_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);


// printf("h=31 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[(31 * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[(31 * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[(31 * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 key zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[(31 * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);
// }
/* Leading dimension for matrices */
uint ldk = TRANSPOSE_K ? KEY_S3 : KEY_S2;
uint ldq = QRY_S2;
Expand Down Expand Up @@ -307,6 +388,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG

#if WITH_ATTN_MASK
mask_tile_type mask_tile;
// tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
tile_load_t(&mask_tile, msk, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
opts_kq.scaleA = params.conf.is_kv_compressed && !kq_common_scales;
opts_kq.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization;

// auto key_dt_size = micro::data_type_size(convert_type(params.inputs[1].GetDType()));
problem_kq.B.layout = micro::MatrixLayout::Pr;
problem_kq.C.layout = micro::MatrixLayout::T;
// problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * key_dt_size));
problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta));
problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM
problem_kq.B.crosspack = 2;
Expand Down Expand Up @@ -337,8 +339,10 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
opts_vs.scaleA = params.conf.is_kv_compressed && !vs_common_scales;
opts_vs.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization;

// auto val_dt_size = micro::data_type_size(convert_type(params.inputs[2].GetDType()));
problem_vs.B.layout = micro::MatrixLayout::Pr;
problem_vs.C.layout = micro::MatrixLayout::N;
// problem_vs.A.setAlignment(micro::alignment_for_ld(head_size * val_dt_size));
problem_vs.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta));
problem_vs.B.setAlignment(64); // S is packed in SLM
problem_vs.B.crosspack = 16;
Expand Down Expand Up @@ -536,6 +540,12 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
// TODO: Causes accuracy drop for static SD model. Enable back once the issue is resolved
// if (lda % 4 == 0 && v_full)
// jit.AddConstant(MakeJitConstant("BLOCK_A", 1));
// if (params.inputs.size() > 3 && !params.inputs[3].is_dynamic()) {
// auto ldmsk = params.inputs[3].X().v * params.inputs[3].ElementSize();
// if (ldmsk % 4 == 0)
// jit.AddConstant(MakeJitConstant("BLOCK_MSK", 1));
// }
// if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
jit.AddConstant(MakeJitConstant("REMAINDER_Q", !q_full));
} else if (params.engineInfo.arch >= gpu_arch::xe_hpc) {
auto vbytes = n_values.v * V.ElementSize();
Expand Down Expand Up @@ -753,6 +763,12 @@ void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
const auto n_queries = get_seq_length(Q, prim_params.input0_order);
const auto n_keys = get_seq_length(K, prim_params.input1_order);

GPU_DEBUG_TRACE_DETAIL << "Key scale pad_before=" << prim_params.key_cache_comp_scale.Y().pad.before
<< "pad_after=" << prim_params.key_cache_comp_scale.Y().pad.after << "\n";

GPU_DEBUG_TRACE_DETAIL << "Value scale pad_before=" << prim_params.value_cache_comp_scale.Y().pad.before
<< "pad_after=" << prim_params.value_cache_comp_scale.Y().pad.after << "\n";

auto head_size = prim_params.conf.head_size;

ScalarDescriptor s_d;
Expand Down

0 comments on commit ead8d8a

Please sign in to comment.