From 4c3e62acff0cc1aeae2b4c2b60b5ae7b651e7d55 Mon Sep 17 00:00:00 2001
From: Xujin Chris Liu <36082433+Aceticia@users.noreply.github.com>
Date: Fri, 29 Nov 2024 15:23:29 -0500
Subject: [PATCH 1/2] Custom pos alibi flash attn fix (#300)

* handle custom pos for alibi in flash attn

* test for custom pos alibi+flash attn
---
 tests/test_x_transformers.py | 10 +++++++---
 x_transformers/attend.py     |  2 +-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index cd1be33d..18c03b48 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,7 +409,8 @@
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
 
     model = Decoder(
         dim = 512,
@@ -417,6 +420,7 @@
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)

diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index d354f915..c2bad988 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -370,7 +370,7 @@ def flash_attn(
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
             # if mask given, the mask would already contain the causal mask from above logic
             # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

From 8c993f0dd850d4b15a737c271c55db54b1db2e6b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 29 Nov 2024 12:23:59 -0800
Subject: [PATCH 2/2] 1.42.20
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 37a3f636..49086114 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.19',
+  version = '1.42.20',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
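
Note (editor's addition, not part of the patch): the attend.py change relies on torch.Tensor.expand being able to prepend new leading dimensions. A head-only alibi bias of shape (h, i, j) and a bias that already carries a batch dimension (as can happen when a custom `pos` is supplied) both broadcast to (batch, heads, i, j) without the removed rearrange call. A minimal standalone sketch of that broadcasting behavior, using plain PyTorch with assumed shapes rather than x-transformers internals:

# illustrative only: the broadcasting the patched line depends on
import torch

batch, heads, i, j = 2, 8, 4, 4

# head-only alibi bias, shape (h, i, j)
bias_3d = torch.randn(heads, i, j)

# bias with a batch dimension, shape (b, h, i, j), e.g. from per-sequence custom positions
bias_4d = torch.randn(batch, heads, i, j)

# expand prepends missing leading dims and leaves -1 dims unchanged,
# so both inputs end up as (batch, heads, i, j)
assert bias_3d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)
assert bias_4d.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)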