[GPU] Add optimization for per-token to reduce calculation
Signed-off-by: Min, Byungil <[email protected]>
byungilm committed Dec 13, 2024
1 parent d83debc commit 33e33de
Showing 2 changed files with 38 additions and 7 deletions.
@@ -969,10 +969,22 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
     // Main computation loop
     const uint iterations = MAIN_LOOP_ELEMENTS_COUNT / TILE_IFM_ELEMENTS_SIZE; // TILE_IFM_ELEMENTS_SIZE : (TILE_IFM * SIMD)
     // Each sub-group loads 2 Batch
-    uint idx_sglid = (sglid * TILE_K) % TILE_IFM_ELEMENTS_SIZE;       // same index for sglid 0~7 : to tile_k direction
-    uint batch_sglid = (sglid * TILE_K) / TILE_IFM_ELEMENTS_SIZE;     // 0 to 1 : to batch direction
-
+    const uint idx_sglid = (sglid * TILE_K) % TILE_IFM_ELEMENTS_SIZE;       // same index for sglid 0~7 : to tile_k direction
+    const uint batch_sglid = (sglid * TILE_K) / TILE_IFM_ELEMENTS_SIZE;     // 0 to 1 : to batch direction
+    const uint scale_pitch = (TILE_IN_B_PITCH / QUANTIZE_GROUP_SIZE);
+
+    #if PER_TOKEN_SIZE_DYN_QUANTIZE
+    // Each token is quantized only once, so all MAIN_LOOP_ELEMENTS_COUNT elements share a single quantization scale
+    uint per_token_offset = input_offset / QUANTIZE_GROUP_SIZE;
+    unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+        de_quantize_scale[bi] = TO_INPUT0_TYPE(quan_var[per_token_offset * 2]);
+        #if COMPRESSED_WEIGHTS_INT8
+            activation_sum[bi] = TO_INPUT0_TYPE(quan_var[per_token_offset * 2 + 1]);
+        #endif
+        per_token_offset += scale_pitch;
+    }
+    #endif
 
     MAKE_VECTOR_TYPE(int, TILE_B) acc_tmp[TILE_OFM] = { };
     __attribute__((opencl_unroll_hint(1)))
     for (uint ni = 0; ni < iterations; ++ni) {
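
The per-token block above hoists the de-quantize scale (and, for int8 weights, the activation sum) out of the main loop: when the quantization group spans the whole token there is exactly one scale per batch row. The standalone C++ sketch below is not part of this commit and uses illustrative names only (quantize_row, group_size); it shows why per-token mode needs a single scale where grouped mode needs one per group.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize `row` to int8 with one scale per `group_size` elements and return
// the per-group de-quantize scales; group_size == row.size() is per-token mode.
static std::vector<float> quantize_row(const std::vector<float>& row,
                                       size_t group_size,
                                       std::vector<int8_t>& quantized) {
    std::vector<float> scales;
    quantized.resize(row.size());
    for (size_t start = 0; start < row.size(); start += group_size) {
        float max_abs = 0.0f;
        for (size_t i = start; i < start + group_size && i < row.size(); ++i)
            max_abs = std::max(max_abs, std::fabs(row[i]));
        const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
        for (size_t i = start; i < start + group_size && i < row.size(); ++i)
            quantized[i] = static_cast<int8_t>(std::lrint(row[i] / scale));
        scales.push_back(scale);  // one de-quantize scale per group
    }
    return scales;
}

int main() {
    const std::vector<float> token = {0.1f, -2.3f, 0.7f, 1.9f, -0.4f, 3.2f, -1.1f, 0.05f};
    std::vector<int8_t> q;
    printf("per-group scales: %zu\n", quantize_row(token, 4, q).size());             // 2 scales
    printf("per-token scales: %zu\n", quantize_row(token, token.size(), q).size());  // 1 scale
}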
@@ -987,7 +999,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             // Next batch
             in_offset += (TILE_IN_B_PITCH * 2);
 
-            #if NUM_LOOP_IN_DYN_QUAN_GROUP == 1
+            #if (PER_TOKEN_SIZE_DYN_QUANTIZE == 0) && (NUM_LOOP_IN_DYN_QUAN_GROUP == 1)
                 de_quantize_scale[bi * 2] = TO_INPUT0_TYPE(quan_var[scale_offset * 2]);
                 de_quantize_scale[bi * 2 + 1] = TO_INPUT0_TYPE(quan_var[scale_offset * 2 + scale_pitch * 2]);
                 #if COMPRESSED_WEIGHTS_INT8
@@ -1000,12 +1012,12 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                 #endif
             }
 
-        #if NUM_LOOP_IN_DYN_QUAN_GROUP > 1
+        #if (PER_TOKEN_SIZE_DYN_QUANTIZE == 0) && (NUM_LOOP_IN_DYN_QUAN_GROUP > 1)
             if (ni % NUM_LOOP_IN_DYN_QUAN_GROUP == 0) {
                 unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
                     de_quantize_scale[bi] = TO_INPUT0_TYPE(quan_var[scale_offset * 2]);
                     #if COMPRESSED_WEIGHTS_INT8
-                        activation_sum[bi] = quan_var[scale_offset * 2 + 1];
+                        activation_sum[bi] = TO_INPUT0_TYPE(quan_var[scale_offset * 2 + 1]);
                     #endif
                     scale_offset += scale_pitch;
                 }
@@ -1205,7 +1217,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             #endif
         } // Whole tile_k elements of each iteration : ki
 
-        #if DQ_DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE <= DECOMPRESSION_SCALE_GROUP_SIZE)
+        #if (PER_TOKEN_SIZE_DYN_QUANTIZE == 0) && DQ_DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE <= DECOMPRESSION_SCALE_GROUP_SIZE)
             // Dynamic-quantizing group size set to same or smaller than scale group size
             if ((ni % NUM_LOOP_IN_DYN_QUAN_GROUP) == (NUM_LOOP_IN_DYN_QUAN_GROUP - 1)) {
                 const uint ni_offset = ((ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
@@ -1233,6 +1245,20 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
         #endif
     } // Main compute loop : ni
 
+    #if PER_TOKEN_SIZE_DYN_QUANTIZE
+    unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+        unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
+            ACCUMULATOR_TYPE ds = d_scales[fi % DECOMPRESSION_SCALE_LENGTH];
+            #if COMPRESSED_WEIGHTS_INT8
+                float modified_calc_buff = ((float)((int *)(&acc_tmp[fi]))[bi]) - ((float)(wei_zp[fi]) * activation_sum[bi]);
+                ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] = (convert_half)(modified_calc_buff) * ds * de_quantize_scale[bi];
+            #else
+                ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] = convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
+            #endif
+        }
+    }
+    #endif
+
     // =====================================================================================================================================
     // Post-processing: bias, activation, fused-ops
     for (uint bi = 0; bi < TILE_B; ++bi) {
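
In the per-token path all de-quantization is deferred to the single pass added above, after the ni loop finishes. The self-contained C++ sketch below is not taken from the kernel; it demonstrates the arithmetic that pass relies on for int8-compressed weights: subtracting wei_zp * activation_sum once from the integer accumulator is equivalent to removing the zero-point from every weight before the multiply, since sum_k (w_k - zp) * a_k = sum_k w_k * a_k - zp * sum_k a_k, and the corrected value is then scaled by the weight decompression scale and the per-token activation scale.

#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> quantized_weights = {12, -7, 33, 5};   // int8 weights stored with zero-point offset
    const std::vector<int> quantized_acts    = {4, 25, -9, 17};   // int8 dynamically quantized activations
    const int   wei_zp              = 3;      // weight zero-point
    const float decompression_scale = 0.02f;  // weight de-quantize scale
    const float de_quantize_scale   = 0.05f;  // per-token activation scale

    int acc = 0, activation_sum = 0;
    for (size_t k = 0; k < quantized_weights.size(); ++k) {
        acc += quantized_weights[k] * quantized_acts[k];  // raw integer MAC, zero-point not yet removed
        activation_sum += quantized_acts[k];
    }
    // Zero-point correction folded in once, after the whole accumulation.
    const float corrected = static_cast<float>(acc) - static_cast<float>(wei_zp) * activation_sum;
    const float result = corrected * decompression_scale * de_quantize_scale;

    // Reference: de-quantize every term first, then accumulate in float.
    float reference = 0.0f;
    for (size_t k = 0; k < quantized_weights.size(); ++k)
        reference += (quantized_weights[k] - wei_zp) * decompression_scale *
                     quantized_acts[k] * de_quantize_scale;

    printf("post-loop dequant: %f, reference: %f\n", result, reference);  // both print the same value
}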
@@ -752,6 +752,11 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
         jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 1));
         jit.AddConstant(MakeJitConstant("DQ_DECOMPRESSION_SCALE_POST_OP", 1));
         jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", quantize_grp_size));
+
+        if (is_per_token_dynamic_quantize(params) && quantize_grp_size == get_input_bf_size(params).second)
+            jit.AddConstant(MakeJitConstant("PER_TOKEN_SIZE_DYN_QUANTIZE", 1));
+        else
+            jit.AddConstant(MakeJitConstant("PER_TOKEN_SIZE_DYN_QUANTIZE", 0));
     } else {
         if (add_decompress_scale_post_op)
             jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
