
Commit

add data dependent alibi from the forgetting transformers paper from iclr 2025
lucidrains committed Oct 29, 2024
1 parent b84babb commit 222f360
Showing 4 changed files with 103 additions and 7 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -2330,4 +2330,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@inproceedings{anonymous2024forgetting,
title = {Forgetting Transformer: Softmax Attention with a Forget Gate},
author = {Anonymous},
booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
year = {2024},
url = {https://openreview.net/forum?id=q2Lnyegkr8},
note = {under review}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.40.11',
version = '1.41.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
17 changes: 17 additions & 0 deletions tests/test_x_transformers.py
@@ -340,3 +340,20 @@ def test_value_residual():
x = torch.randint(0, 20000, (2, 1024))

model(x)

def test_forgetting_transformer():

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 128,
depth = 6,
heads = 8,
attn_data_dependent_alibi = True
)
)

x = torch.randint(0, 20000, (2, 1024))

embed = model(x)
80 changes: 74 additions & 6 deletions x_transformers/x_transformers.py
@@ -101,6 +101,12 @@ def log(t, eps = 1e-20):
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max

def reverse_cumsum(t, dim = -1):
t = t.flip(dims = (dim,))
t = t.cumsum(dim = dim)
t = t.flip(dims = (dim,))
return t
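To make the new helper concrete: it computes a suffix sum along the chosen dimension. A quick standalone check (values illustrative, not part of this commit):

```python
import torch

t = torch.tensor([1., 2., 3., 4.])

# mirrors reverse_cumsum above: flip, take the cumulative sum, flip back
suffix_sum = t.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

print(suffix_sum)  # tensor([10., 9., 7., 4.]) - each entry is the sum of itself and everything after it
```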

def l2norm(t, groups = 1):
t = rearrange(t, '... (g d) -> ... g d', g = groups)
t = F.normalize(t, p = 2, dim = -1)
@@ -352,8 +358,10 @@ def __init__ (
self.soft_onehot = soft_onehot
self.soft_onehot_temp = soft_onehot_temp

if soft_onehot:
self.register_buffer('positions', torch.arange(max_pos))
if not soft_onehot:
return

self.register_buffer('positions', torch.arange(max_pos))

def forward(self, query, attn_logits):

@@ -375,7 +383,7 @@ def forward(self, query, attn_logits):
logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)

if self.soft_onehot:
diff_pos = (pos[..., None] - self.positions).abs()
diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
else:
@@ -491,6 +499,44 @@ def forward(self, i, j):

return self.bias

class DataDependentAlibi(Module):
""" https://openreview.net/forum?id=q2Lnyegkr8 """

def __init__(
self,
dim,
heads
):
super().__init__()

linear = nn.Linear(dim, heads)

self.to_forget_gates = nn.Sequential(
linear,
Rearrange('b n h -> b h n'),
nn.Sigmoid()
)

nn.init.constant_(linear.bias, 5.)

def forward(self, x):
seq = x.shape[-2]

forget_gates = self.to_forget_gates(x).log()
forget_gates = repeat(forget_gates, 'b h j -> b h i j', i = seq)

# zero out the log forget gates at and above the diagonal, so that after the reverse cumsum a token's attention to itself is never decayed

causal_mask = torch.ones((seq, seq), dtype = torch.bool, device = x.device).triu()

forget_gates = forget_gates.masked_fill(causal_mask, 0.)

# reverse cumulative sum in log space (equivalent to cumprod)

forget_gates = reverse_cumsum(forget_gates)

return forget_gates
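For intuition, the bias this module returns works out to D[i, j] = log f_j + ... + log f_{i-1} for j < i and 0 otherwise, where f_k = sigmoid(Linear(x_k)) is the per-head, per-position forget gate, so adding it to the pre-softmax logits scales each unnormalized attention score by exp(D[i, j]), a data-dependent decay in (0, 1]. A standalone sanity sketch (illustrative shapes, not part of this commit) checking the vectorized log-space computation against a naive double loop:

```python
import torch

batch, heads, seq = 1, 2, 6

# one forget gate per (batch, head, position), as produced by Linear + Sigmoid above
log_gates = torch.sigmoid(torch.randn(batch, heads, seq)).log()

# vectorized path, mirroring DataDependentAlibi.forward:
# broadcast over the query index i, zero out positions j >= i, then suffix-sum over j
bias = log_gates[:, :, None, :].expand(batch, heads, seq, seq).clone()
causal_mask = torch.ones((seq, seq), dtype = torch.bool).triu()
bias = bias.masked_fill(causal_mask, 0.)
bias = bias.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

# naive reference: D[i, j] = sum of log gates for positions j .. i - 1
ref = torch.zeros(batch, heads, seq, seq)
for i in range(seq):
    for j in range(i):
        ref[:, :, i, j] = log_gates[:, :, j:i].sum(dim = -1)

assert torch.allclose(bias, ref, atol = 1e-6)
print(bias[0, 0].exp())  # decay factors in (0, 1] below the diagonal, exactly 1 on and above it
```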

class RotaryEmbedding(Module):
def __init__(
self,
@@ -939,6 +985,7 @@ def __init__(
tensor_product = False, # https://arxiv.org/abs/2208.06061
add_zero_kv = False, # same as add_zero_attn in pytorch
rotary_embed_values = False,
data_dependent_alibi = False,
use_cope = False,
cope_max_pos = 16,
cope_soft_onehot_pos = False,
@@ -1042,6 +1089,19 @@ def __init__(
soft_onehot = cope_soft_onehot_pos
)

# data dependent alibi
# https://openreview.net/forum?id=q2Lnyegkr8

self.data_dependent_alibi = None

if data_dependent_alibi:
assert causal, 'data dependent alibi only works with autoregressive (causal) attention for now, until further research'

self.data_dependent_alibi = DataDependentAlibi(
dim,
heads = heads
)

# attend class - includes core attention algorithm + talking heads

self.attend = Attend(
@@ -1252,6 +1312,11 @@ def forward(
attn_bias = rel_pos(i, j)
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

# prepare data dependent alibi from forgetting transformers paper, if needed

if exists(self.data_dependent_alibi):
attn_bias = self.data_dependent_alibi(x)

# if previous values passed in for residual, either invoke resformer or neutreno

if exists(value_residual):
@@ -1389,10 +1454,13 @@ def __init__(
attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
add_value_residual |= neutreno_value_residual

self.dim = dim
self.causal = causal
@@ -1405,7 +1473,7 @@ def __init__(
assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None

assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

# relative positional bias
