Another error is that it only supports FP16 inputs, so I convert q/k/v from FP32 to FP16 and convert the output from FP16 back to FP32 after fa2.
After these corrections, I can run fa2-cm normally.
However, the results look bad: the gradients explode.
Could my modifications above be the cause, or is something else likely going on?
Looking forward to your reply!
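One plausible cause of the exploding gradients is the FP32→FP16 round-trip itself: FP16 has a maximum value of 65504 and a much smaller subnormal range than FP32, so large activations overflow to `inf` and tiny gradients underflow to zero. The sketch below uses numpy as a stand-in for torch tensors to show both failure modes; it is an illustration of the casting issue, not the fa2 code itself.

```python
import numpy as np

# FP16 has a far smaller dynamic range than FP32.
FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0

def to_fp16_and_back(x_fp32):
    """Round-trip a float32 array through float16, as done around fa2."""
    return x_fp32.astype(np.float16).astype(np.float32)

# Values beyond the FP16 range overflow to inf, which then propagates
# through the backward pass...
big = np.array([1e5], dtype=np.float32)
print(to_fp16_and_back(big))   # [inf]

# ...and tiny gradients underflow to zero. This is why mixed-precision
# training normally pairs FP16 with loss scaling.
tiny = np.array([1e-8], dtype=np.float32)
print(to_fp16_and_back(tiny))  # [0.]
```

If the casts are kept, wrapping training in a loss-scaling scheme (e.g. PyTorch's AMP grad scaler) is the usual mitigation.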
For FP32 support, I need to make a minor edit. This was attempted by someone else earlier, but their code was buggy so I had to revert it. I will make this edit later this week and let you know!
Do you have any idea why FP32 would produce wrong results in the backward pass?
Oh, it shouldn't. Basically, I hardcoded FP16 in the code, carried over from the original Triton example. It should instead take the dtype of the tensors being fed in.
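The fix described above, taking the dtype from the inputs rather than hardcoding FP16, can be sketched as follows. This is a hypothetical stand-in (numpy instead of torch, a naive softmax(QK^T)V reference instead of the Triton kernel), meant only to show the dtype-handling pattern.

```python
import numpy as np

def attention(q, k, v):
    # All inputs must share one floating dtype; use it for the output
    # rather than assuming float16 (the bug described above).
    assert q.dtype == k.dtype == v.dtype, "q/k/v dtypes must match"
    dtype = q.dtype

    # Naive softmax(QK^T)V reference; the result is cast back to the
    # inputs' dtype instead of a hardcoded one.
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v).astype(dtype)

q = k = v = np.random.rand(4, 8).astype(np.float32)
out = attention(q, k, v)
assert out.dtype == np.float32  # output follows the input dtype
```

The same call with FP16 inputs would yield an FP16 output, so no explicit casting is needed at the call site.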
Hi, thanks for your work. I'm trying your fa2-cm, but it raises an error because of the following assertion:

I solved this problem by adding
`.contiguous()`
as follows:
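The assertion in question is the kind of stride check that fails on transposed or permuted tensor views, which is why `.contiguous()` fixes it. A minimal numpy analog of the workaround (assuming the assertion is a standard innermost-stride/contiguity check, which the issue does not show):

```python
import numpy as np

# A transposed array is a view with non-row-major strides; kernels that
# assert contiguity (innermost stride == 1) reject it.
x = np.ones((4, 8), dtype=np.float16).T   # shape (8, 4), transposed view
print(x.flags['C_CONTIGUOUS'])            # False

# Analog of torch's .contiguous(): copy into row-major layout so the
# kernel's stride assertion passes.
x_fixed = np.ascontiguousarray(x)
print(x_fixed.flags['C_CONTIGUOUS'])      # True
```

In torch, the equivalent is calling `q.contiguous()` (and likewise for k/v) before handing the tensors to the kernel; the copy has a cost, but it is usually negligible next to the attention computation.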