From e2bab6d04e4d9a4ca05843dda3cc2fb6a75e6bf0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 29 Oct 2024 10:01:25 -0700
Subject: [PATCH] follow some suggestions from @faresobeid

---
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 26 ++++----------------------
 2 files changed, 5 insertions(+), 23 deletions(-)

diff --git a/setup.py b/setup.py
index 33c48144..7e0780e2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.41.0',
+  version = '1.41.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 7f58c401..cc80f6d4 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -101,12 +101,6 @@ def log(t, eps = 1e-20):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
-def reverse_cumsum(t, dim = -1):
-    t = t.flip(dims = (dim,))
-    t = t.cumsum(dim = dim)
-    t = t.flip(dims = (dim,))
-    return t
-
 def l2norm(t, groups = 1):
     t = rearrange(t, '... (g d) -> ... g d', g = groups)
     t = F.normalize(t, p = 2, dim = -1)
@@ -514,27 +508,15 @@ def __init__(
         self.to_forget_gates = nn.Sequential(
             linear,
             Rearrange('b n h -> b h n'),
-            nn.Sigmoid()
+            nn.LogSigmoid()
         )
 
         nn.init.constant_(linear.bias, 5.)
 
     def forward(self, x):
-        seq = x.shape[-2]
-
-        forget_gates = self.to_forget_gates(x).log()
-        forget_gates = repeat(forget_gates, 'b h j -> b h i j', i = seq)
-
-        # causal mask out, including diagonal (so token to itself attention is never masked out)
-
-        causal_mask = torch.ones((seq, seq), dtype = torch.bool, device = x.device).triu()
-
-        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
-
-        # reverse cumulative sum in log space (equivalent to cumprod)
-
-        forget_gates = reverse_cumsum(forget_gates)
-
+        forget_gates = self.to_forget_gates(x)
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
 class RotaryEmbedding(Module):
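
Note (not part of the patch above): the new forward replaces the repeat + causal mask + reverse cumulative sum with a single cumulative sum of log-sigmoid gates followed by a pairwise difference, so entry (i, j) of the result, for j <= i, is the sum of log gates over positions j+1 .. i. The sketch below is a minimal, standalone illustration of that computation, assuming only PyTorch; the tensor names, toy shapes, and the use of plain broadcasting in place of einx.subtract are illustrative only.

    # minimal sketch of what the patched forward computes (illustrative, not the library code)

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    batch, heads, seq = 1, 2, 5
    gate_logits = torch.randn(batch, heads, seq)         # stand-in for the linear projection's output

    # per-position log forget gates, as produced by nn.LogSigmoid() in the patched Sequential
    log_gates = F.logsigmoid(gate_logits)                 # (b, h, n)

    # patched path: cumulative sum in log space, then pairwise differences
    # (plain broadcasting here in place of einx.subtract, which computes the same thing)
    cum = log_gates.cumsum(dim = -1)                      # (b, h, n)
    decay_bias = cum.unsqueeze(-1) - cum.unsqueeze(-2)    # (b, h, i, j) = cum_i - cum_j

    # for j <= i, entry (i, j) equals the sum of log gates over positions j+1 .. i,
    # i.e. the log of the cumulative forget decay between the two positions
    i, j = 4, 1
    manual = log_gates[..., j + 1:i + 1].sum(dim = -1)
    assert torch.allclose(decay_bias[..., i, j], manual, atol = 1e-6)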