From 85cffc6ccc6da8397aa1dc6fa07a07b2781c7a89 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 31 Oct 2024 12:01:49 -0700
Subject: [PATCH] improvise a bidirectional forgetting transformer

---
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 25 +++++++++++++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index f806650c..75aaa19b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.4',
+  version = '1.42.5',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index a8d2a675..34750f31 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -512,12 +512,15 @@ def __init__(
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1.
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
@@ -529,9 +532,21 @@ def __init__(
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
@@ -541,10 +556,13 @@ def __init__(
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
@@ -1138,10 +1156,9 @@ def __init__(
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
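
As a rough companion to the patch, the sketch below walks through the bias that the new DataDependentAlibi.forward computes when causal = False: the linear layer now emits two sets of per-token log forget gates (heads * 2), both are cumulatively summed along the sequence, and the final attention bias is assembled from a lower-triangular half (forward gates) and an upper-triangular half (reversed gates). It is written in plain PyTorch rather than einx, and the helper name bidirectional_forget_bias and its signature are invented for illustration; they do not exist in x-transformers.

import torch
import torch.nn as nn
import torch.nn.functional as F

def bidirectional_forget_bias(x, to_gates, post_log_scale = 1.):
    # x: (batch, seq, dim); to_gates: nn.Linear(dim, heads * 2)
    # per-token log forget gates in (-inf, 0), one set per direction
    gates = F.logsigmoid(to_gates(x)) * post_log_scale   # (b, n, 2h)
    gates = gates.transpose(1, 2)                         # (b, 2h, n)
    gates = gates.cumsum(dim = -1)                        # prefix sums of log gates

    fwd, rev = gates.chunk(2, dim = 1)                    # (b, h, n) each

    # causal half: bias[i, j] = fwd[i] - fwd[j], the summed gates between j and i (j <= i)
    causal_bias = fwd[..., :, None] - fwd[..., None, :]

    # anti-causal half: bias[i, j] = rev[j] - rev[i], the summed gates between i and j (j >= i)
    anti_causal_bias = rev[..., None, :] - rev[..., :, None]

    # lower triangle decays toward the past, upper triangle toward the future
    return causal_bias.tril() + anti_causal_bias.triu()

if __name__ == '__main__':
    batch, seq, dim, heads = 2, 16, 64, 8
    to_gates = nn.Linear(dim, heads * 2)
    nn.init.constant_(to_gates.bias, 5.)                  # mirrors bias_init = 5., i.e. near-zero forgetting at init
    bias = bidirectional_forget_bias(torch.randn(batch, seq, dim), to_gates)
    print(bias.shape)                                      # torch.Size([2, 8, 16, 16])

If this reading of the diff is right, the bidirectional variant should also be reachable from the library itself, for example by building a non-causal layer stack such as an Encoder with attn_data_dependent_alibi = True, since the final hunk forwards the attention layer's causal flag into the alibi module instead of asserting causality.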