update get_split_k to fix a performance regression on FA decode #1040
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##             main    #1040      +/-   ##
==========================================
- Coverage   59.64%   59.58%   -0.06%
==========================================
  Files         114      114
  Lines       10223    10232       +9
==========================================
  Hits         6097     6097
- Misses       4126     4135       +9
Flags with carried forward coverage won't be shown.
@scxiao thanks! Do you have the performance results (on both AMD and H100)?
Just confirming: this PR will not regress H100 perf?
xformers/ops/fmha/triton_splitk.py
Outdated
@@ -1028,19 +1028,31 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
@classmethod
def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
    """Heuristic for the number of splits"""
    print(f"B = {B}, G = {G}, H = {H}, Mk = {Mk}")
nit: remove.
Sorry, it seems I pushed an unexpected commit to this PR. I have converted it to a draft and will let you know when the code is cleaned up.
    split_k_upper_bound = 512
else:
    max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
    split_k_stop_val = Mk / max_chunk_size
    split_k_upper_bound = 64

while split_k > split_k_stop_val:
    split_k = split_k // 2
while split_k > split_k_stop_val:
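The halving loop in the fragment above can be sketched in isolation. This is an illustrative standalone version, not the xformers implementation; the initial split_k and the stop value are chosen arbitrarily for the example:

```python
def halve_split_k(split_k: int, split_k_stop_val: float) -> int:
    # Repeatedly halve split_k until it no longer exceeds the stop value.
    # Mirrors the loop shown in the diff fragment above.
    while split_k > split_k_stop_val:
        split_k = split_k // 2
    return split_k

# halve_split_k(64, 10.0) halves 64 -> 32 -> 16 -> 8 and stops at 8.
```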
Confirming: is this an intended change?
Sorry, it seems I pushed an unexpected commit to this PR. I have converted it to a draft and will let you know when the code is cleaned up.
Just double checked: we intend the changes to impact only the HIP side, which is why the changes above are made here.
Thanks. I will check and get back to you.
Changes in this PR only apply to the calculation of split_k in the HIP backend, so they will not impact the CUDA side.
Hi @jianyuh, @shagunsodhani, when you get a chance, could you please take a look at this PR? Thanks.
xformers/ops/fmha/triton_splitk.py
Outdated
if torch.version.hip:
    max_chunk_size = 64
    split_k_stop_val = min(Mk / max_chunk_size, 1024 / (B * G * H))
    split_k_stop_val = 1024 / (B * G * H)
    while split_k > 0 and Mk / (split_k - 1) < max_chunk_size:
Should it be while split_k > 1 to prevent division by 0?
Thanks, yes, good catch: it should be split_k > 1.
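With the corrected guard, the loop can be sketched as follows. The diff fragment does not show the loop body, so decrementing split_k inside the loop is an assumption made for illustration:

```python
def shrink_split_k(split_k: int, Mk: int, max_chunk_size: int) -> int:
    # Guard with split_k > 1 so (split_k - 1) never reaches 0 in the
    # division below, avoiding a ZeroDivisionError.
    # The decrement in the body is an assumption; the diff omits it.
    while split_k > 1 and Mk / (split_k - 1) < max_chunk_size:
        split_k -= 1
    return split_k
```

With split_k == 1 the loop exits immediately instead of dividing by zero, which is exactly the failure mode the review comment flagged.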
LGTM. This is a small change on nvidia (adding
Thanks, I reverted that change related to nvidia.
Hi all, I am wondering whether anyone has additional comments for this PR? If not, could we get this PR merged? Thanks.
LGTM, change only affects AMD and verified internally in D57316421
Hi @jianyuh, could you please help get this PR merged if no additional comments? Thanks
What does this PR do?
Fixes a performance regression for FA decode. The changes are in the function get_split_k(), which computes the number of split-K partitions for FA decode in Triton. This PR also adds a test to verify the correctness of different implementations of FA decoders; you can run the tests as:
pytest benchmark_attn_decoding.py -v
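Piecing together the diff fragments quoted in the conversation, the HIP-vs-CUDA branching of the heuristic might look roughly like the sketch below. This is an illustrative reconstruction, not the authoritative xformers implementation: the initial split_k value is an assumption (here taken as the upper bound), and only the names visible in the fragments (bh, max_chunk_size, split_k_stop_val, split_k_upper_bound) are from the PR:

```python
def get_split_k_sketch(B: int, G: int, H: int, Mk: int, is_hip: bool) -> int:
    # Illustrative reconstruction from the diff fragments in this PR;
    # not the actual xformers get_split_k.
    bh = B * G * H
    if is_hip:
        # HIP branch: cap the split count both by chunk size and by total
        # batch*group*head parallelism, per the fragments above.
        max_chunk_size = 64
        split_k_stop_val = min(Mk / max_chunk_size, 1024 / bh)
        split_k_upper_bound = 512
    else:
        # CUDA branch left unchanged by the PR, per the discussion.
        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
        split_k_stop_val = Mk / max_chunk_size
        split_k_upper_bound = 64
    split_k = split_k_upper_bound  # starting point is an assumption
    while split_k > split_k_stop_val:
        split_k = split_k // 2
    return max(split_k, 1)
```

The key point of the PR, per the author's comments, is that only the is_hip branch changes, so CUDA behavior is preserved.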
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.