illegal memory access with seqlen = 2048 and dimension = 64 #2
Comments
Hm, interesting, let me take a look this weekend. I was having issues earlier with the backward pass and I thought I fixed them by reverting my changes to match the original implementation, but I will investigate. Thanks for catching this!
@Uwwal I'm unable to replicate this error on my side currently. I wasn't using torch.benchmark, but I just ran:
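The snippet the maintainer ran is not preserved in the thread text. As an illustration only, a forward-only check with the shapes from the issue title (my assumption, not the maintainer's exact code) might look like:

```python
import torch
from fa2_custom_mask import flash_attention_custom_mask

# Hypothetical forward-only check; the shapes come from the issue title,
# not from the maintainer's original (unshown) snippet.
batch_size, num_heads, seq_len, dimension = 64, 16, 2048, 64
q, k, v = (
    torch.randn(batch_size, num_heads, seq_len, dimension, dtype=torch.float16, device="cuda")
    for _ in range(3)
)
mask = torch.tril(torch.ones(batch_size, num_heads, seq_len, seq_len, dtype=torch.bool, device="cuda"))

out = flash_attention_custom_mask(q, k, v, mask, 0.5)
print(out.shape)
```

The following comments suggest the crash shows up once the backward pass is involved.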
What version of Triton / PyTorch / CUDA do you have installed?
triton 3.0.0
I ran this code and it runs well.
You can reproduce the error by adding a backward pass, like:

```python
import torch
import torch.nn.functional as F
import triton
import torch.autograd.profiler as profiler
from fa2_custom_mask import flash_attention_custom_mask

torch.manual_seed(42)

def func(y, grad):
    y.backward(grad, retain_graph=True)

batch_size = 8
num_heads = 28
seq_len = 4096
dimension = 128

q = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
mask = torch.tril(torch.ones((batch_size, num_heads, seq_len, seq_len), dtype=torch.bool, device="cuda", requires_grad=False))

torch.cuda.synchronize()
torch.cuda.empty_cache()

y = flash_attention_custom_mask(q, k, v, mask, 0.5).half()
grad_output = torch.ones_like(y)
y.backward(grad_output)

for x in [q, k, v]:
    print(x.grad)
print(y)
```

The reason is in these lines of the kernel:

```python
# line 45
maskT_ptrs = mask + offs_m[None, :].to(tl.int64) * mask_stride_tok + offs_n[:, None].to(tl.int64) * mask_stride_tokk
...
# line 105
mask_ptrs = mask + offs_m[:, None].to(tl.int64) * mask_stride_tok + offs_n[None, :].to(tl.int64) * mask_stride_tokk
...
# line 163
m_adj = mask_stride_h.to(tl.int64) * (bhid % H) + mask_stride_z.to(tl.int64) * (bhid // H)
```
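My reading of the excerpt above (an interpretation, not stated explicitly in the thread): the boolean mask is a contiguous (batch, heads, seq_len, seq_len) tensor, so its element count exceeds the signed 32-bit range for both configurations in this thread, and without the `tl.int64` casts the 32-bit offset arithmetic wraps around, which would explain the illegal memory access. A quick back-of-the-envelope check:

```python
# Overflow check for the mask offsets (assumes a contiguous boolean mask,
# which is how the reproduction code above constructs it).
INT32_MAX = 2**31 - 1

for batch, heads, seq_len in [(8, 28, 4096), (64, 16, 2048)]:
    num_mask_elements = batch * heads * seq_len * seq_len
    print(
        f"B={batch} H={heads} N={seq_len}: {num_mask_elements:,} mask elements, "
        f"overflows int32: {num_mask_elements - 1 > INT32_MAX}"
    )

# B=8 H=28 N=4096: 3,758,096,384 mask elements, overflows int32: True
# B=64 H=16 N=2048: 4,294,967,296 mask elements, overflows int32: True
```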
params:
batch_size: 64, num_heads: 16, seq_len: 2048, dimension: 64, dtype: torch.float16, with mask, on an A800
example:
error:
some attempts:
The issue seems to be related to an incorrect tiling block size, so I modified the config parameters in the fa2_fwd file,
but the issue still occurs.
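For context, "modified the config parameters in the fa2_fwd file" presumably refers to the Triton autotune configs for the forward kernel; a hypothetical sketch of shrinking the tile sizes (the actual list and key names in fa2_fwd may differ) would be:

```python
import triton

# Hypothetical illustration of a smaller autotune search space; the real
# config list in fa2_fwd may use different block-size names and values.
configs = [
    triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n},
                  num_stages=num_stages, num_warps=num_warps)
    for block_m in [32, 64]      # smaller query tiles
    for block_n in [32, 64]      # smaller key/value tiles
    for num_stages in [2, 3]
    for num_warps in [4]
]
```

If the root cause is the 32-bit offset overflow discussed in the comments above, smaller tiles alone would not be expected to fix it, which is consistent with the error persisting here.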
Looking forward to your response.