illegal memory access with seqlen = 2048 and dimension = 64 #2

Open

Uwwal opened this issue Aug 1, 2024 · 5 comments

Uwwal (Contributor) commented Aug 1, 2024

params:

batch_size: 64, num_heads: 16, seq_len: 2048, dimension: 64, dtype: torch.float16, with a custom mask, on an A800

example:

import torch
import torch.utils.benchmark as benchmark

from fa2_custom_mask import flash_attention_custom_mask

def func(y, grad):
    y.backward(grad, retain_graph=True)

batch_size, num_heads, seq_len, dimension = 64, 16, 2048, 64

q = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())

mask = torch.tril(torch.ones((batch_size, num_heads, seq_len, seq_len), dtype=torch.uint8, device="cuda", requires_grad=False))

t = benchmark.Timer(
    stmt="flash_attention_custom_mask(q, k, v, mask, scale)",
    globals={"flash_attention_custom_mask": flash_attention_custom_mask, "q": q, "k": k,"v":v,"mask":mask, "scale":0.5},
    num_threads=torch.get_num_threads(),
)
fwd_time = t.timeit(10).mean * 1000

torch.cuda.synchronize()
torch.cuda.empty_cache()

y = flash_attention_custom_mask(q, k, v, mask, 0.5).half()

for x in [q,k,v]:
    x.grad = None

grad = torch.rand_like(y)

t = benchmark.Timer(
    stmt="f(y,grad)",
    globals={"f": func, "y": y, "grad": grad},
    num_threads=torch.get_num_threads(),
)
bwd_time = t.timeit(10).mean * 1000

error:

Traceback (most recent call last):
  File "week5_mask_test.py", line 372, in <module>
    bwd_time = t.timeit(10).mean * 1000
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/usr/lib/python3.8/timeit.py", line 177, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "week5_mask_test.py", line 64, in func
    y.backward(grad, retain_graph=True)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/hadoop-perception/flashattention2-custom-mask-main/fa2_custom_mask/fa2_custom_mask.py", line 87, in backward
    _attn_bwd[grid](
  File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/home/hadoop-perception/.local/lib/python3.8/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Some attempts:

The issue seems to be related to an incorrect tiling block size, so I modified the config parameters in the fa2_fwd file:

configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [64]          # 128 removed
    for BN in [32, 64]
    for s in ([1] if is_hip() else [3, 4, 7])
    for w in [4, 8]
]
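
As a side check, one could also inspect which config the autotuner actually picks. The sketch below is hypothetical: it assumes the forward kernel is an @triton.autotune-wrapped function named _attn_fwd in fa2_custom_mask/fa2_fwd.py, which may not match the actual names in the repo.

# Hypothetical: print the config Triton's autotuner selected for the forward kernel.
from fa2_custom_mask.fa2_fwd import _attn_fwd   # assumed name and location

_ = flash_attention_custom_mask(q, k, v, mask, 0.5)   # first call triggers autotuning
print(_attn_fwd.best_config)   # shows BLOCK_M, BLOCK_N, num_stages, num_warps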

Even with these smaller block sizes, the issue still occurs.

Looking forward to your response.

alexzhang13 (Owner) commented:

Hm, interesting, let me take a look this weekend. I was having issues with the backward pass earlier and thought I had fixed them by reverting my changes to match the original implementation, but I will investigate. Thanks for catching this!

alexzhang13 (Owner) commented:

@Uwwal I'm unable to replicate this error on my side currently. I wasn't using torch.benchmark, but I just ran:

import torch
import torch.nn.functional as F
import triton
import torch.autograd.profiler as profiler

from fa2_custom_mask import flash_attention_custom_mask

def func(y, grad):
    y.backward(grad, retain_graph=True)

batch_size=64
num_heads=16
seq_len=2048
dimension=64
    
q = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())

mask = torch.tril(torch.ones((batch_size, num_heads, seq_len, seq_len), dtype=torch.uint8, device="cuda", requires_grad=False))

torch.cuda.synchronize()
torch.cuda.empty_cache()

y = flash_attention_custom_mask(q, k, v, mask, 0.5).half()

for x in [q,k,v]:
    x.grad = None

grad = torch.rand_like(y)

print(y)
print(grad)

What version of Triton / PyTorch / CUDA do you have installed?
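
If it helps, a quick way to print those versions (assuming standard pip installs of torch and triton) is:

import torch
import triton

# PyTorch version, Triton version, and the CUDA version PyTorch was built against
print(torch.__version__, triton.__version__, torch.version.cuda)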

Uwwal (Contributor, Author) commented Aug 3, 2024

triton 3.0.0
torch 2.2.0a0+git39901f2
cuda 11.8

Uwwal (Contributor, Author) commented Aug 5, 2024

I ran the code above and it runs well.
I guess it's caused by the torch benchmark.
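
One way to test that guess (reusing q, k, v, and y from the snippet above) is to repeat the backward call in a plain loop, with no torch.utils.benchmark involved; if this also crashes, the Timer itself is probably not the trigger.

# Plain-loop backward check, no torch.utils.benchmark involved.
grad = torch.rand_like(y)
for _ in range(10):
    for x in [q, k, v]:
        x.grad = None               # clear grads so each backward starts clean
    y.backward(grad, retain_graph=True)
torch.cuda.synchronize()            # surface any asynchronous CUDA error here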

chenlidar commented:

You can reproduce the error by adding a backward pass, like:

import torch
import torch.nn.functional as F
import triton
import torch.autograd.profiler as profiler

from fa2_custom_mask import flash_attention_custom_mask

torch.manual_seed(42)

def func(y, grad):
    y.backward(grad, retain_graph=True)

batch_size=8
num_heads=28
seq_len=4096
dimension=128

    
q = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((batch_size, num_heads, seq_len, dimension), dtype=torch.bfloat16, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())

mask = torch.tril(torch.ones((batch_size, num_heads, seq_len, seq_len), dtype=torch.bool, device="cuda", requires_grad=False))

torch.cuda.synchronize()
torch.cuda.empty_cache()

y = flash_attention_custom_mask(q, k, v, mask, 0.5).half()

grad_output = torch.ones_like(y)

y.backward(grad_output)

for x in [q,k,v]:
    print(x.grad)


print(y)

The reason is that 8*28*4096*4096 > 2**31 (or 64*16*2048*2048 > 2**31), which causes an int32 overflow in the mask offset arithmetic.
You can fix this by modifying fa2_bwd.py like:

# line 45
maskT_ptrs = mask + offs_m[None, :].to(tl.int64) * mask_stride_tok + offs_n[:, None].to(tl.int64) * mask_stride_tokk
...
# line 105
mask_ptrs = mask + offs_m[:, None].to(tl.int64) * mask_stride_tok + offs_n[None, :].to(tl.int64) * mask_stride_tokk
...
# line 163
m_adj = mask_stride_h.to(tl.int64) * (bhid % H) + mask_stride_z.to(tl.int64) * (bhid // H)
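
As a quick sanity check of the overflow claim, plain Python arithmetic (no Triton needed) shows that the largest linear mask offset exceeds the int32 range in both failing configurations:

INT32_MAX = 2**31 - 1
for b, h, s in [(64, 16, 2048), (8, 28, 4096)]:
    last_offset = b * h * s * s - 1   # largest linear offset into the (B, H, S, S) mask
    print(b, h, s, last_offset, last_offset > INT32_MAX)
# Both exceed INT32_MAX, so 32-bit pointer offsets wrap around and the kernel
# accesses memory out of bounds -> illegal memory access.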
