Commit 003275c: address #305

lucidrains committed Dec 23, 2024
1 parent cdf51f7 · commit 003275c

Showing 3 changed files with 6 additions and 4 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.1',
+  version = '1.43.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py (6 changes: 4 additions & 2 deletions)

@@ -409,7 +409,8 @@ def test_custom_alibi(flash: bool):

     logits = model(x, pos = pos)

-def test_custom_rotary_pos_emb():
+@pytest.mark.parametrize('rotary_xpos', (True, False))
+def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat

     model = TransformerWrapper(
@@ -419,7 +420,8 @@ def test_custom_rotary_pos_emb():
             dim = 512,
             depth = 2,
             heads = 8,
-            rotary_pos_emb = True
+            rotary_pos_emb = True,
+            rotary_xpos = rotary_xpos
         )
     )

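For context, here is a hedged, self-contained sketch of the configuration this test now covers. The inner layer arguments mirror the hunk above; num_tokens, the batched custom positions, and the use of Decoder (xpos rotary scaling requires causal attention in x-transformers) are assumptions for illustration, not values taken from the test.

import torch
from einops import repeat
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,              # assumed vocab size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True,
        rotary_xpos = True         # the newly parametrized flag
    )
)

x = torch.randint(0, 256, (2, 64))

# custom positions with a leading batch dimension -- the shape that the
# '... n' rearrange fix below accommodates (plausibly the scenario in #305)
pos = repeat(torch.arange(64), 'n -> b n', b = 2)

logits = model(x, pos = pos)       # (2, 64, 256)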
x_transformers/x_transformers.py (2 changes: 1 addition & 1 deletion)

@@ -666,7 +666,7 @@ def forward(self, t):
            return freqs, 1.

        power = (t - (max_pos // 2)) / self.scale_base
-        scale = self.scale ** rearrange(power, 'n -> n 1')
+        scale = self.scale ** rearrange(power, '... n -> ... n 1')
        scale = torch.stack((scale, scale), dim = -1)
        scale = rearrange(scale, '... d r -> ... (d r)')

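The pattern change is small but load-bearing: 'n -> n 1' only accepts a one-dimensional position tensor, while '... n -> ... n 1' also tolerates leading batch dimensions. A minimal einops illustration (the shapes are assumptions for demonstration):

import torch
from einops import rearrange

power = torch.arange(8).float()                      # unbatched positions, shape (8,)
print(rearrange(power, 'n -> n 1').shape)            # torch.Size([8, 1])

batched = torch.arange(16).float().reshape(2, 8)     # batched positions, shape (2, 8)
# rearrange(batched, 'n -> n 1') raises an EinopsError: the pattern expects
# exactly one axis. The '...' wildcard passes leading dimensions through:
print(rearrange(batched, '... n -> ... n 1').shape)  # torch.Size([2, 8, 1])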
