address #303, allow for rotary embedding for cross attention if context is in shared positional space as input
lucidrains committed Dec 10, 2024
1 parent 66be236 commit 70c59b6
Showing 3 changed files with 62 additions and 9 deletions.
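As background for the commit message: with rotary embeddings, a query at position m and a key at position n are each rotated by angles proportional to their own position, so their dot product depends on the positions only through the offset m − n. That offset is only meaningful when the context tokens are indexed in the same positional space as the input, which is what the new `context_pos` argument supplies. A minimal self-contained sketch of the idea (the `rope` helper below is illustrative only, not the library's `apply_rotary_pos_emb`):

```python
import torch

def rope(x, pos, theta = 10000.):
    # minimal rotary embedding: rotate feature pairs of x by angles pos * theta^(-2i/dim)
    dim = x.shape[-1]
    freqs = theta ** -(torch.arange(0, dim, 2).float() / dim)   # (dim / 2,)
    angles = pos.float()[:, None] * freqs[None, :]              # (seq, dim / 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# queries take the positions of x, keys take the positions of the context;
# q @ k.T then depends only on the relative offsets between the two position sets
q = rope(torch.randn(64, 32), pos = torch.arange(64))
k = rope(torch.randn(128, 32), pos = torch.arange(128))
sim = q @ k.t()   # (64, 128) attention logits in a shared positional frame
```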
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.24',
+  version = '1.42.25',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
29 changes: 29 additions & 0 deletions tests/test_x_transformers.py
@@ -557,3 +557,32 @@ def test_laser():
     x = torch.randint(0, 20000, (2, 1024))

     model(x)
+
+@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+def test_cross_attn_rotary(
+    cross_attn_rotary: bool
+):
+
+    x = torch.randn((1, 64, 256))
+    mask = torch.ones((1, 64)).bool()
+    context = torch.randn((1, 128, 512))
+    context_mask = torch.ones((1, 128)).bool()
+
+    model = Encoder(
+        dim = 256,
+        depth = 4,
+        heads = 4,
+        rotary_pos_emb = True,
+        cross_attend = True,
+        cross_attn_dim_context = 512
+    )
+
+    context_pos = torch.arange(128)
+
+    embed = model(
+        x = x,
+        mask = mask,
+        context = context,
+        context_pos = context_pos if cross_attn_rotary else None,
+        context_mask = context_mask
+    )
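Read as usage rather than as a test, the new path looks roughly like the sketch below. The slice offsets, shapes and the use of `pos` alongside `context_pos` are illustrative assumptions based on this test and on the existing `pos` hook in `AttentionLayers.forward`, not part of the commit; when `context_pos` is omitted, the cross-attention blocks receive no rotary frequencies and behave as before.

```python
import torch
from x_transformers import Encoder

model = Encoder(
    dim = 256,
    depth = 4,
    heads = 4,
    rotary_pos_emb = True,
    cross_attend = True,
    cross_attn_dim_context = 512
)

# suppose x and context embed two spans of the *same* underlying sequence,
# e.g. tokens 256..319 and tokens 0..127, so their indices are comparable
x = torch.randn(1, 64, 256)
context = torch.randn(1, 128, 512)

pos = torch.arange(256, 320)      # where the x tokens sit in the shared space
context_pos = torch.arange(128)   # where the context tokens sit in that same space

embed = model(
    x,
    pos = pos,                    # optional; defaults to 0..n-1 when omitted
    context = context,
    context_pos = context_pos,    # omit (or pass None) to skip rotary in cross attention
    context_mask = torch.ones(1, 128).bool()
)
```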
40 changes: 32 additions & 8 deletions x_transformers/x_transformers.py
@@ -1284,6 +1284,7 @@ def forward(
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,19 @@
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale

-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
+
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

         if self.rotary_embed_values:
@@ -1916,6 +1925,7 @@ def forward(
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1975,14 +1985,28 @@

         # rotary positions

-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
-
-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
-
-            rotary_pos_emb = self.rotary_pos_emb(pos)
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided
+
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)
+
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )

         # assume cached key / values
@@ -2107,7 +2131,7 @@ def forward(
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
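A small note on the last hunk: the new arguments reach the cross-attention block through a kwargs dict, so callers that never pass `context_pos` see no change in behaviour. A tiny sketch of that pattern (the `block` stub is hypothetical, not the library's attention block):

```python
def block(x, rotary_pos_emb = None, context_rotary_pos_emb = None):
    # stand-in for the cross-attention block; just echoes what it received
    return x, rotary_pos_emb, context_rotary_pos_emb

# context_pos not provided: the dict stays empty, the call expands to nothing
# extra and the block falls back to its defaults (previous behaviour)
cross_attn_rotary_pos_emb = dict()
assert block('x', **cross_attn_rotary_pos_emb) == ('x', None, None)

# context_pos provided: both frequency sets are forwarded explicitly
cross_attn_rotary_pos_emb.update(rotary_pos_emb = 'q_freqs', context_rotary_pos_emb = 'k_freqs')
assert block('x', **cross_attn_rotary_pos_emb) == ('x', 'q_freqs', 'k_freqs')
```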
