Commit e7bd229

Merge branch 'chris-use' into rotary-emb-custom-pos

Aceticia authored Nov 29, 2024
2 parents 243d298 + bacdd84

Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions tests/test_x_transformers.py
@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)

-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):

     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )

@@ -429,8 +431,8 @@ def test_custom_rotary_pos_emb():
     logits2 = model(x)
     assert torch.allclose(logits1, logits2)

-def test_custom_alibi_across_heads():
-
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,
         depth = 2,
@@ -439,6 +441,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )

     x = torch.randn(2, 4, 512)
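
The parametrization above runs each custom-alibi test twice, once with flash attention enabled and once without. A minimal sketch of what a single parametrized case exercises, assuming x_transformers is installed; the model construction mirrors the first test above, while the input batch and shape check are added here purely for illustration:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    # mirrors test_custom_alibi with flash = True
    model = TransformerWrapper(
        num_tokens = 20_000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 2,
            heads = 8,
            alibi_pos_bias = True,
            attn_flash = True  # stands in for the parametrized `flash` value
        )
    )

    x = torch.randint(0, 20_000, (2, 128))  # illustrative token batch, not taken from the test
    logits = model(x)
    assert logits.shape == (2, 128, 20_000)
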
2 changes: 1 addition & 1 deletion x_transformers/attend.py
@@ -370,7 +370,7 @@ def flash_attn(
         # convert from bool to float

         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)

         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
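
The replaced line no longer fixes the bias to a 3-dimensional (heads, rows, cols) layout: torch.Tensor.expand prepends any missing leading dimensions on its own, so the same call also accepts a bias that already carries a leading, possibly singleton, batch-like dimension, which the einops pattern 'h i j -> 1 h i j' would reject. A small illustrative check of that shape behavior, with dimensions chosen arbitrarily:

    import torch

    batch, heads, i, j = 2, 8, 16, 16

    # per-head bias without a batch dimension: expand prepends one
    bias_hij = torch.randn(heads, i, j)
    assert bias_hij.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)

    # a bias that already has a singleton leading dimension broadcasts the same way
    bias_1hij = torch.randn(1, heads, i, j)
    assert bias_1hij.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)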
