
AssertionError: First input (fp16) and second input (fp32) must have the same dtype! #14

Open
3bobo opened this issue Oct 4, 2024 · 1 comment

3bobo commented Oct 4, 2024

Thanks for your work! However, when I run the example from the README on my 3090, the code does not work. If I add `.half()` to convert the inputs to fp16, it works, but I would like to use fp32. How can I fix this?

alexzhang13 (Owner) commented

This is because I initialize some objects to FP16 by default -- this was done in the Triton example that this repo is based on. I've unfortunately been super busy recently, but if you want to fix this yourself, you'll have to make sure nothing gets explicitly converted to FP16. See https://github.com/alexzhang13/flashattention2-custom-mask/blob/main/fa2_custom_mask/fa2_fwd.py#L87 for an example. Find all of these hardcoded FP16 defaults, change them to the dtype of the input tensor, and it should work!
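The general pattern of the fix is to derive buffer dtypes from the input tensor rather than hardcoding `float16`. A minimal NumPy sketch of that principle (the actual repo uses PyTorch/Triton, and the function name here is hypothetical, not from the codebase):

```python
import numpy as np

def make_output_buffer(q: np.ndarray) -> np.ndarray:
    """Allocate an output buffer whose dtype matches the input tensor.

    Buggy pattern (causes the dtype-mismatch assertion for fp32 inputs):
        out = np.empty_like(q, dtype=np.float16)  # hardcoded fp16

    Fixed pattern: inherit the dtype from the input.
    """
    return np.empty_like(q, dtype=q.dtype)

# fp32 inputs now yield fp32 buffers; fp16 inputs still yield fp16.
q32 = np.zeros((4, 8), dtype=np.float32)
q16 = np.zeros((4, 8), dtype=np.float16)
assert make_output_buffer(q32).dtype == np.float32
assert make_output_buffer(q16).dtype == np.float16
```

Applying the same substitution at every place the kernel hardcodes FP16 (accumulators, output allocations, explicit casts) should let fp32 inputs pass the `same dtype` assertion.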

