
Commit

if the hybrid module is an RNN, allow for folding it across the sequence for efficiency, and to answer the question of just how much recurrence is needed
lucidrains committed Jan 5, 2025
1 parent c51ecd3 commit b28d82a
Showing 3 changed files with 48 additions and 4 deletions.
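
A minimal usage sketch of the new `attn_hybrid_fold_axial_dim` kwarg, mirroring the test changes below; the `num_tokens` and `max_seq_len` values are illustrative and not taken from this commit:

import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,    # illustrative
    max_seq_len = 1024,    # illustrative
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_fold_axial_dim = 4,  # fold the sequence by 4 before running the GRU
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
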
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.44.2',
version = '1.44.4',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
5 changes: 4 additions & 1 deletion tests/test_x_transformers.py
@@ -614,7 +614,8 @@ def test_hyper_connections(tanh):

model(x)

def test_hybrid():
@pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
def test_hybrid(hybrid_axial_dim):
from torch.nn import GRU

dec = TransformerWrapper(
@@ -625,6 +626,7 @@ def test_hybrid():
depth = 6,
heads = 8,
attn_dim_head = 64,
attn_hybrid_fold_axial_dim = hybrid_axial_dim,
attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
)
)
@@ -641,6 +643,7 @@ def test_hybrid():
depth = 6,
heads = 8,
attn_dim_head = 64,
attn_hybrid_fold_axial_dim = hybrid_axial_dim,
attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
)
)
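
The parametrization above exercises both hybrid_axial_dim = 1 (no folding, the wrapper is a passthrough) and 4 (each GRU strand recurs over every 4th position). A rough back-of-the-envelope sketch of how much recurrence is left after folding, using illustrative numbers not taken from the test:

import math

seq_len, axial_dim = 1024, 4                  # illustrative sequence length and fold factor
strand_len = math.ceil(seq_len / axial_dim)   # each folded strand recurs over 256 steps
num_strands = 2 * axial_dim                   # a batch of 2 sequences becomes 8 strands for the GRU
print(strand_len, num_strands)                # 256 8
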
45 changes: 43 additions & 2 deletions x_transformers/x_transformers.py
@@ -10,7 +10,7 @@
from torch.amp import autocast
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.nn import Module, ModuleList, ModuleDict

from functools import partial, wraps
@@ -966,6 +966,42 @@ def forward(self, x, **kwargs):
x = torch.cat((*segments_to_shift, *rest), dim = -1)
return self.fn(x, **kwargs)

class FoldAxially(Module):
def __init__(
self,
axial_dim,
fn: Module
):
super().__init__()
self.fn = fn
self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")

def forward(
self,
x,
**kwargs
):
if self.axial_dim == 1:
return self.fn(x, **kwargs)

seq_len, axial_dim = x.shape[1], self.axial_dim

next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)

x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)

out = self.fn(x, **kwargs)

(out, *rest_out), tree_spec = tree_flatten(out)

out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)

out = out[:, :seq_len]
out = tree_unflatten((out, *rest_out), tree_spec)

return out

# post branch operator

class LayerScale(Module):
@@ -1140,6 +1176,7 @@ def __init__(
custom_attn_fn: Callable | None = None,
hybrid_module: Module | None = None,
hybrid_mask_kwarg: str | None = None,
hybrid_fold_axial_dim: int | None = None,
one_kv_head = False,
kv_heads = None,
shared_kv = False,
@@ -1341,8 +1378,12 @@ def __init__(

# hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
hybrid_module = maybe(deepcopy)(hybrid_module)

if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)

self.hybrid_module = hybrid_module
self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

# output dimension by default same as input, but can be overridden
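
A standalone sketch (not part of the diff) of what the new FoldAxially wrapper does around a GRU: fold a length-12 sequence by axial_dim = 4 into 4 shorter strands per batch element, run the wrapped module over the strands, then unfold and truncate back to the original length. The import path assumes FoldAxially is reachable at module level in x_transformers/x_transformers.py, as the diff above defines it:

import torch
from torch.nn import GRU
from x_transformers.x_transformers import FoldAxially

gru = GRU(128, 128, batch_first = True)
folded = FoldAxially(axial_dim = 4, fn = gru)

x = torch.randn(2, 12, 128)             # (batch, seq, dim); 12 is divisible by 4, so no padding is needed
out, hidden = folded(x)                 # GRU returns (output, h_n); tree_unflatten restores the tuple
assert out.shape == (2, 12, 128)        # sequence order restored after unfolding and truncation
assert hidden.shape == (1, 2 * 4, 128)  # hidden state keeps the folded batch of strands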
