From 70c59b6defb442b349fffe57e733fa5b14d813b6 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 10 Dec 2024 07:15:05 -0800
Subject: [PATCH] address https://github.com/lucidrains/x-transformers/issues/303, allow for rotary embedding for cross attention if context is in shared positional space as input

---
 setup.py                         |  2 +-
 tests/test_x_transformers.py     | 29 +++++++++++++++++++++++
 x_transformers/x_transformers.py | 40 +++++++++++++++++++++++++-------
 3 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/setup.py b/setup.py
index 4691a733..5b3bcca7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.24',
+  version = '1.42.25',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index ef231b67..8944115d 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -557,3 +557,32 @@ def test_laser():
 
     x = torch.randint(0, 20000, (2, 1024))
     model(x)
+
+@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+def test_cross_attn_rotary(
+    cross_attn_rotary: bool
+):
+
+    x = torch.randn((1, 64, 256))
+    mask = torch.ones((1, 64)).bool()
+    context = torch.randn((1, 128, 512))
+    context_mask = torch.ones((1, 128)).bool()
+
+    model = Encoder(
+        dim = 256,
+        depth = 4,
+        heads = 4,
+        rotary_pos_emb = True,
+        cross_attend = True,
+        cross_attn_dim_context = 512
+    )
+
+    context_pos = torch.arange(128)
+
+    embed = model(
+        x = x,
+        mask = mask,
+        context = context,
+        context_pos = context_pos if cross_attn_rotary else None,
+        context_mask = context_mask
+    )
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index ac820989..c9bae1c0 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1284,6 +1284,7 @@ def forward(
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,19 @@
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale
 
-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
+
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
 
             if self.rotary_embed_values:
@@ -1916,6 +1925,7 @@ def forward(
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1975,14 +1985,28 @@
 
         # rotary positions
 
-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided
 
-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)
 
-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )
 
         # assume cached key / values
 
@@ -2107,7 +2131,7 @@
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
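
For reference, a minimal usage sketch of the `context_pos` argument introduced by this patch, adapted from the `test_cross_attn_rotary` test added above. It assumes x-transformers >= 1.42.25; the dimensions, depths, and sequence lengths are illustrative only, not requirements.

import torch
from x_transformers import Encoder

# encoder that cross attends to a context living in the same positional space as the input

model = Encoder(
    dim = 256,
    depth = 4,
    heads = 4,
    rotary_pos_emb = True,        # rotary embeddings must be enabled on the encoder
    cross_attend = True,
    cross_attn_dim_context = 512
)

x = torch.randn(1, 64, 256)
mask = torch.ones(1, 64).bool()

context = torch.randn(1, 128, 512)
context_mask = torch.ones(1, 128).bool()

# positions of the context tokens - passing them enables rotary embeddings in the
# cross attention blocks (queries use the input positions, keys use `context_pos`);
# omitting `context_pos` keeps the previous behavior of no rotary in cross attention

context_pos = torch.arange(128)

embed = model(
    x,
    mask = mask,
    context = context,
    context_pos = context_pos,
    context_mask = context_mask
)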