
Commit

follow some suggestions from @faresobeid
lucidrains committed Oct 29, 2024
1 parent 222f360 commit e2bab6d
Showing 2 changed files with 5 additions and 23 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.41.0',
+  version = '1.41.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (26 changes: 4 additions & 22 deletions)

@@ -101,12 +101,6 @@ def log(t, eps = 1e-20):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
-def reverse_cumsum(t, dim = -1):
-    t = t.flip(dims = (dim,))
-    t = t.cumsum(dim = dim)
-    t = t.flip(dims = (dim,))
-    return t
-
 def l2norm(t, groups = 1):
     t = rearrange(t, '... (g d) -> ... g d', g = groups)
     t = F.normalize(t, p = 2, dim = -1)
@@ -514,27 +508,15 @@ def __init__(
         self.to_forget_gates = nn.Sequential(
             linear,
             Rearrange('b n h -> b h n'),
-            nn.Sigmoid()
+            nn.LogSigmoid()
         )
 
         nn.init.constant_(linear.bias, 5.)
 
     def forward(self, x):
-        seq = x.shape[-2]
-
-        forget_gates = self.to_forget_gates(x).log()
-        forget_gates = repeat(forget_gates, 'b h j -> b h i j', i = seq)
-
-        # causal mask out, including diagonal (so token to itself attention is never masked out)
-
-        causal_mask = torch.ones((seq, seq), dtype = torch.bool, device = x.device).triu()
-
-        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
-
-        # reverse cumulative sum in log space (equivalent to cumprod)
-
-        forget_gates = reverse_cumsum(forget_gates)
-
+        forget_gates = self.to_forget_gates(x)
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
 class RotaryEmbedding(Module):
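A note on what the rewritten forward computes (not part of the commit): with log forget gates g = logsigmoid(Wx + b), the cumulative sum gives c_i = sum of g_k for k <= i, and the pairwise difference D[i, j] = c_i - c_j sums the log gates over positions j+1 up to i, i.e. the log of the cumulative product of forget gates strictly after position j, with a zero diagonal so a token never decays attention to itself. A minimal sketch, assuming a single head and plain torch broadcasting in place of einx.subtract:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq = 6

# hypothetical stand-in for self.to_forget_gates(x): log forget gates for one head
g = F.logsigmoid(torch.randn(seq))

c = g.cumsum(dim = -1)
D = c[:, None] - c[None, :]  # same as einx.subtract('i, j -> i j', c, c)

# D[i, j] sums the log gates over positions j+1 ..= i
i, j = 4, 1
assert torch.allclose(D[i, j], g[j + 1:i + 1].sum())

# zero diagonal: token-to-self attention is never decayed,
# matching the old code's masking of the diagonal
assert torch.allclose(D.diagonal(), torch.zeros(seq))

The Sigmoid to LogSigmoid swap then keeps the gates in log space from the start; F.logsigmoid is the numerically stabler way to obtain log(sigmoid(x)), avoiding the log of a saturated sigmoid that the old .log() call could hit.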
