
update get_split_k to fix a performance regression on FA decode #1040

Merged
merged 17 commits into from
May 17, 2024

Conversation

scxiao
Contributor

@scxiao scxiao commented Apr 30, 2024

What does this PR do?

Fixes a performance regression for FA decode. The changes are in the function get_split_k(), which computes the number of split-K partitions for the FA decode kernel in Triton.

This PR also adds a test to verify the correctness of the different FA decoder implementations; you can run the tests with:
pytest benchmark_attn_decoding.py -v
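To make the discussion below easier to follow, here is a minimal sketch of the kind of split-k heuristic this PR tunes. The function name matches the PR, but the thresholds, the starting value, and the helper variable bh are illustrative assumptions, not the exact xformers code:

```python
# Hypothetical sketch of a split-k heuristic; constants are illustrative,
# not the actual values in xformers/ops/fmha/triton_splitk.py.
def get_split_k(B: int, G: int, H: int, Mk: int, max_chunk_size: int = 64) -> int:
    """Pick how many chunks to split the K/V sequence into for decode."""
    bh = max(B * G * H, 1)                  # total batch * group * head work units
    split_k = max(Mk, 1024) // bh           # start from a generous split count
    split_k_stop_val = Mk / max_chunk_size  # keep each chunk >= max_chunk_size rows
    while split_k > split_k_stop_val:
        split_k //= 2                       # halve until chunks are large enough
    return max(split_k, 1)
```

The trade-off the heuristic balances: more splits expose more parallelism across CUs/SMs, but each split adds reduction overhead, so small K/V lengths want fewer splits.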

Before submitting

  • Did you have fun?
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 30, 2024
@codecov-commenter

codecov-commenter commented Apr 30, 2024

Codecov Report

Attention: Patch coverage is 0%, with 13 lines in your changes missing coverage. Please review.

Project coverage is 59.58%. Comparing base (22d092e) to head (79da1da).
Report is 1 commit behind head on main.

Files Patch % Lines
xformers/ops/fmha/triton_splitk.py 0.00% 13 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1040      +/-   ##
==========================================
- Coverage   59.64%   59.58%   -0.06%     
==========================================
  Files         114      114              
  Lines       10223    10232       +9     
==========================================
  Hits         6097     6097              
- Misses       4126     4135       +9     
Flag Coverage Δ
Python 59.58% <0.00%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@xw285cornell
Contributor

@scxiao thanks! Do you have the performance results (on both AMD and H100)?

Member

@jianyuh jianyuh left a comment

Just confirming: this PR will not regress H100 perf?

@@ -1028,19 +1028,31 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
@classmethod
def get_split_k(cls, B: int, G: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
print(f"B = {B}, G = {G}, H = {H}, Mk = {Mk}")
Member

nit: remove.

Contributor Author

Sorry, it seems I pushed an unexpected commit to this PR. I just converted this PR to a draft; I will let you know when I have cleaned up the code.

split_k_upper_bound = 512
else:
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
split_k_stop_val = Mk / max_chunk_size
split_k_upper_bound = 64

while split_k > split_k_stop_val:
split_k = split_k // 2
while split_k > split_k_stop_val:
@shagunsodhani shagunsodhani May 8, 2024

Confirming that this is an intended change?

Contributor Author

Sorry, it seems I pushed an unexpected commit to this PR. I just converted this PR to a draft; I will let you know when I have cleaned up the code.

Contributor Author

Just double-checked: we want the changes to impact only the HIP side, hence the changes above.

@scxiao scxiao marked this pull request as draft May 8, 2024 16:51
@scxiao
Contributor Author

scxiao commented May 8, 2024

@scxiao thanks! Do you have the performance results (on both AMD and H100)?

I will check and get back to you


@scxiao scxiao force-pushed the update_get_splitK branch from 5ece4ca to 1e47433 Compare May 13, 2024 19:19
@scxiao scxiao marked this pull request as ready for review May 13, 2024 19:22
@scxiao
Contributor Author

scxiao commented May 13, 2024

Just confirming: this PR will not regress H100 perf?

The changes in this PR apply only to the calculation of split_k in the HIP backend, so they will not impact the CUDA side.
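A minimal sketch of this backend gating pattern. Here `is_hip` stands in for the actual `torch.version.hip` check (which is truthy on ROCm builds and None on CUDA builds); the chunk-size thresholds are taken from the diff quoted earlier in this conversation, but the function name and signature are assumptions for illustration:

```python
# Illustrative sketch of gating a tuned heuristic on the HIP backend only,
# so the CUDA path is left byte-for-byte unchanged. `is_hip` models
# bool(torch.version.hip); thresholds mirror the quoted diff.
def pick_max_chunk_size(Mk: int, bh: int, is_hip: bool) -> int:
    if is_hip:
        # Tuned AMD path: smaller chunks for short sequences / few heads
        return 64 if Mk <= 512 and bh <= 64 else 128
    # CUDA path: behavior before this PR, untouched
    return 128
```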

@scxiao
Contributor Author

scxiao commented May 14, 2024

Hi @jianyuh, @shagunsodhani, when you get a chance, could you please take a look at this PR? Thanks.

if torch.version.hip:
max_chunk_size = 64
split_k_stop_val = min(Mk / max_chunk_size, 1024 / (B * G * H))
split_k_stop_val = 1024 / (B * G * H)
while split_k > 0 and Mk / (split_k - 1) < max_chunk_size:

Should it be while split_k > 1 to prevent division by 0?

Contributor Author

Thanks, yes, good catch; it should be split_k > 1.
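To make the bug concrete, here is a sketch of the loop under discussion. The loop body (decrementing split_k) is an assumption, since the diff hunk above only shows the loop condition: with the guard `split_k > 0`, the expression Mk / (split_k - 1) divides by zero once split_k reaches 1, which the `split_k > 1` guard prevents.

```python
# Illustrative reproduction of the reviewer's catch: guarding with
# split_k > 1 (not > 0) keeps Mk / (split_k - 1) from dividing by zero.
def shrink_split_k(split_k: int, Mk: int, max_chunk_size: int) -> int:
    # Decrease split_k while one fewer split would still leave
    # each chunk smaller than max_chunk_size.
    while split_k > 1 and Mk / (split_k - 1) < max_chunk_size:
        split_k -= 1
    return split_k
```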

@sgrigory sgrigory requested a review from bottler May 14, 2024 08:15
@bottler
Contributor

bottler commented May 14, 2024

LGTM. This is a small change on nvidia (adding bh-1), and a larger change on amd.

@scxiao
Contributor Author

scxiao commented May 14, 2024

LGTM. This is a small change on nvidia (adding bh-1), and a larger change on amd.

Thanks, I reverted that change related to nvidia.

@scxiao
Contributor Author

scxiao commented May 16, 2024

Hi all, I am wondering whether anyone has additional comments for this PR? If not, could we get this PR merged? Thanks.

@zixi-qi zixi-qi left a comment

LGTM, the change only affects AMD and was verified internally in D57316421.

@scxiao
Contributor Author

scxiao commented May 17, 2024

Hi @jianyuh, could you please help merge this PR if there are no additional comments? Thanks.

@bottler bottler merged commit 6e1718b into facebookresearch:main May 17, 2024
7 checks passed