Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [GPU] Use FP32 accumulator for QK multiplication for 2nd+ token calculation in PagedAttention #28673

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
#endif

// SLM for intermediate QK results
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];

// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
Expand Down Expand Up @@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)(
#endif
const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE;

INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;

#define KEY_VEC_SIZE SUBGROUP_SIZE
unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) {
Expand All @@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)(

unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
#if STORE_QUERY_TO_SLM
qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc);
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
#else
qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc);
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
#endif
}
}
Expand All @@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)(
#endif

if (token_idx >= seq_len)
qk_acc = INPUT0_VAL_MIN;
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;

qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));

Expand Down Expand Up @@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)(
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
slm_qk_vals[local_data_idx] = qk_new;

exp_sum += qk_new;
}
Expand Down Expand Up @@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)(
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
slm_qk_vals[local_data_idx] = qk_new;
}
}

Expand Down
Loading