Thanks for your work! However, when I run the example from the README on my 3090, the code doesn't work. If I add `.half()` to convert the tensors to fp16, it works, but I would like to use fp32. How can I fix this?
This is because I initialize some objects to FP16 by default -- this comes from the Triton example that this repo is based on. I've unfortunately been very busy recently, but if you want to fix this yourself, you'll have to make sure nothing gets explicitly converted to FP16. See https://github.com/alexzhang13/flashattention2-custom-mask/blob/main/fa2_custom_mask/fa2_fwd.py#L87 for an example. Change all of these hard-coded FP16 spots to the dtype of the input tensor and it should work!
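For reference, here is a minimal sketch of the kind of change meant above (the variable and function names are illustrative, not the repo's actual ones). The idea is to allocate host-side buffers with the input's dtype instead of a hard-coded `torch.float16`, and to make in-kernel casts follow the input pointer's element type rather than `tl.float16`:

```python
import torch
import triton
import triton.language as tl


def allocate_output(q: torch.Tensor) -> torch.Tensor:
    # Before: o = torch.empty_like(q, dtype=torch.float16)  (hard-coded fp16)
    # After:  inherit whatever dtype q already carries (fp16, bf16, or fp32)
    return torch.empty_like(q)


@triton.jit
def _cast_example(Q, O, n_elements, BLOCK: tl.constexpr):
    # Inside the kernel, replace explicit casts like acc.to(tl.float16)
    # with the element type of the input pointer (Q.dtype.element_ty),
    # so fp32 inputs are written back as fp32.
    offs = tl.arange(0, BLOCK)
    mask = offs < n_elements
    q = tl.load(Q + offs, mask=mask)
    acc = q * 2.0  # stand-in for the real accumulation
    tl.store(O + offs, acc.to(Q.dtype.element_ty), mask=mask)


if __name__ == "__main__":
    x = torch.randn(8, device="cuda", dtype=torch.float32)
    out = allocate_output(x)
    _cast_example[(1,)](x, out, x.numel(), BLOCK=8)
    print(out.dtype)  # stays torch.float32
```

This is just a sketch under the assumption that the remaining fp32 failures come from those hard-coded casts/allocations; the actual spots to edit are the ones in `fa2_fwd.py` linked above.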