Skip to content

Commit

Permalink
add option to use mixture of softmax in TransformerWrapper + some cle…
Browse files Browse the repository at this point in the history
…anup
  • Loading branch information
lucidrains committed Sep 25, 2024
1 parent c9e0d2f commit e6c488e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 35 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2221,4 +2221,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```
@article{Yang2017BreakingTS,
title = {Breaking the Softmax Bottleneck: A High-Rank RNN Language Model},
author = {Zhilin Yang and Zihang Dai and Ruslan Salakhutdinov and William W. Cohen},
journal = {ArXiv},
year = {2017},
volume = {abs/1711.03953},
url = {https://api.semanticscholar.org/CorpusID:26238954}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.36.0',
version = '1.37.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
20 changes: 20 additions & 0 deletions tests/test_x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,23 @@ def test_recycling():
model.eval()

eval_logits = model(x, recycle_steps = 3)

def test_mos():
model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
mixture_of_softmax = True,
attn_layers = Decoder(
dim = 128,
depth = 6,
heads = 8
)
)

x = torch.randint(0, 20000, (2, 1024))

logits = model(x)

model.eval()

eval_logits = model(x, recycle_steps = 3)
4 changes: 3 additions & 1 deletion x_transformers/autoregressive_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def forward(self, x, return_outputs = False, **kwargs):
**kwargs
)

loss = F.cross_entropy(
loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss

loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
target,
ignore_index = ignore_index
Expand Down
105 changes: 72 additions & 33 deletions x_transformers/x_transformers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Callable

import math
from random import random, randrange
Expand All @@ -14,7 +15,6 @@
from collections import namedtuple
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Dict, Tuple, Callable

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
Expand All @@ -28,14 +28,16 @@

@dataclass
class LayerIntermediates:
hiddens: List[Tensor] | None = None # all hiddens, before the final norm (in pre-norm architecture)
hiddens: list[Tensor] | None = None # all hiddens, before the final norm (in pre-norm architecture)
last_hidden: Tensor | None = None # very last hidden after all attention layers, after the final norm
attn_intermediates: List[Intermediates] | None = None
layer_hiddens: List[Tensor] | None = None
attn_intermediates: list[Intermediates] | None = None
layer_hiddens: list[Tensor] | None = None
attn_z_loss: Tensor | None = None
mems: Tensor | None = None
memory_tokens: Tensor | None = None

LinearNoBias = partial(nn.Linear, bias = False)

# helpers

def exists(val):
Expand Down Expand Up @@ -92,6 +94,9 @@ def Sequential(*modules):

# tensor helpers

def log(t, eps = 1e-20):
return t.clamp(min = eps).log()

def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max

Expand All @@ -114,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
den = mask.sum(dim = dim).clamp(min = 1.)
return num / den

def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
if pad == (0, 0):
return t

Expand All @@ -131,7 +136,7 @@ def or_reduce(masks):
# auxiliary loss helpers

def calc_z_loss(
pre_softmax_attns: List[Tensor],
pre_softmax_attns: list[Tensor],
mask = None,
weight = 1.
):
Expand Down Expand Up @@ -611,7 +616,7 @@ def __init__(
dim_condition = default(dim_condition, dim)

self.ln = nn.LayerNorm(dim, elementwise_affine = False)
self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
self.to_gamma = LinearNoBias(dim_condition, dim)
nn.init.zeros_(self.to_gamma.weight)

def forward(self, x, *, condition):
Expand Down Expand Up @@ -666,7 +671,7 @@ def __init__(
self.scale = dim ** 0.5
dim_condition = default(dim_condition, dim)

self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
self.to_gamma = LinearNoBias(dim_condition, dim)
nn.init.zeros_(self.to_gamma.weight)

def forward(self, x, *, condition):
Expand Down Expand Up @@ -749,7 +754,7 @@ def forward(self, x, **kwargs):
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim = -1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
x = torch.cat((*segments_to_shift, *rest), dim = -1)
return self.fn(x, **kwargs)

Expand Down Expand Up @@ -817,7 +822,7 @@ class ConcatCombine(Module):
def __init__(self, dim, prev_layer_ind):
super().__init__()
self.prev_layer_ind = prev_layer_ind
self.combine = nn.Linear(dim * 2, dim, bias = False)
self.combine = LinearNoBias(dim * 2, dim)

def forward(self, x, prev_layers: list[Tensor]):
skip = prev_layers[self.prev_layer_ind]
Expand Down Expand Up @@ -957,17 +962,17 @@ def __init__(
v_dim = value_dim_head * kv_heads
out_dim = value_dim_head * heads

self.to_q = nn.Linear(dim, q_dim, bias = False)
self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
self.to_q = LinearNoBias(dim, q_dim)
self.to_k = LinearNoBias(dim_kv, k_dim)

# shared key / values, for further memory savings during inference

assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

# relations projection from tp-attention

self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

# add GLU gating for aggregated values, from alphafold2

Expand Down Expand Up @@ -1063,7 +1068,7 @@ def __init__(
# output dimension by default same as input, but can be overridden

dim_out = default(dim_out, dim)
self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)
self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

# whether to rotate positions into values, for absolute positions in addition to relative

Expand Down Expand Up @@ -1109,7 +1114,7 @@ def forward(

q = rearrange(q, 'b n (h d) -> b h n d', h = h)

k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))

if exists(cache):
ck, cv = cache.cached_kv
Expand Down Expand Up @@ -1164,12 +1169,12 @@ def forward(

# i, j determined for relative positional bias, excluding memory key / values

i, j = map(lambda t: t.shape[-2], (q, k))
i, j = tuple(t.shape[-2] for t in (q, k))

# maybe append memory key / values

if num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

if self.qk_norm:
mem_k = l2norm(mem_k)
Expand Down Expand Up @@ -1302,8 +1307,8 @@ def __init__(
rotary_xpos_scale_base = 512,
rotary_base_rescale_factor = 1.,
weight_tie_layers = False,
custom_layers: Tuple[str, ...] | None = None,
layers_execute_order: Tuple[int, ...] | None = None,
custom_layers: tuple[str, ...] | None = None,
layers_execute_order: tuple[int, ...] | None = None,
sandwich_coef = None,
par_ratio = None,
residual_attn = False,
Expand Down Expand Up @@ -1464,7 +1469,7 @@ def __init__(

if self.need_condition and adaptive_condition_mlp:
self.adaptive_mlp = nn.Sequential(
nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
nn.SiLU()
)

Expand Down Expand Up @@ -1635,7 +1640,7 @@ def forward(
return_hiddens = False,
rotary_pos_emb = None,
condition = None,
layers_execute_order: Tuple[int, ...] | None = None
layers_execute_order: tuple[int, ...] | None = None
):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
Expand Down Expand Up @@ -1973,7 +1978,7 @@ def __init__(
num_tokens,
max_seq_len,
attn_layers: AttentionLayers,
embed_num_tokens: Dict[str, int] = dict(),
embed_num_tokens: dict[str, int] = dict(),
emb_dim = None,
max_mem_len = 0,
shift_mem_down = 0,
Expand All @@ -1996,6 +2001,8 @@ def __init__(
use_cls_token = False,
squeeze_out_last_dim = False,
token_emb: TokenEmbedding | None = None,
mixture_of_softmax = False,
mixture_of_softmax_k = 4,
):
super().__init__()

Expand Down Expand Up @@ -2050,7 +2057,7 @@ def __init__(
# maybe recycling

self.recycling = recycling
self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
self.recycled_proj = LinearNoBias(dim, dim) if recycling else None

self.train_max_recycle_steps = train_max_recycle_steps

Expand All @@ -2066,21 +2073,37 @@ def __init__(

self.average_pool_embed = average_pool_embed

# output type

self.is_log_prob = mixture_of_softmax

self.to_mixture = None
self.combine_mixture = None

if mixture_of_softmax:
assert num_output_heads == 1

self.to_mixture = Sequential(
LinearNoBias(dim, dim * mixture_of_softmax_k),
Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
)

self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)

# output head, usually to logits of num_tokens

logits_dim = default(logits_dim, num_tokens)

self.has_multiple_heads = False
self.has_multiple_heads = num_output_heads > 1

if return_only_embed:
self.to_logits = None
elif tie_embedding:
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
elif num_output_heads > 1:
self.has_multiple_heads = True
self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
else:
self.to_logits = nn.Linear(dim, logits_dim, bias = False)
self.to_logits = LinearNoBias(dim, logits_dim)

# memory tokens (like [cls]) from Memory Transformers paper

Expand Down Expand Up @@ -2124,7 +2147,7 @@ def forward(
pos = None,
prepend_embeds = None,
prepend_mask = None,
embed_ids: Dict[str, Tensor] = dict(),
embed_ids: dict[str, Tensor] = dict(),
sum_embeds = None,
return_attn_z_loss = False,
attn_z_loss_weight = 1e-4,
Expand Down Expand Up @@ -2281,6 +2304,14 @@ def forward(
if exists(self.cls_token):
x, _ = unpack(x, cls_packed_shape, 'b * d')

# handle expansion to mixture if needed (for mixture of softmax)

combine_mixture = None

if exists(self.to_mixture):
combine_mixture = self.combine_mixture(x).softmax(dim = -1)
x = self.to_mixture(x)

# projecting to logits

if not return_embeddings:
Expand All @@ -2289,6 +2320,14 @@ def forward(
else:
logits = self.to_logits(x)

# handle maybe combine mixture

if exists(combine_mixture):
with autocast('cuda', enabled = False):
prob = logits.softmax(dim = -1)
mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
logits = log(mos)

# maybe squeeze out last dimension of logits

if self.squeeze_out_last_dim:
Expand All @@ -2309,14 +2348,14 @@ def forward(
# aux loss

if return_attn_z_loss:
pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
return_intermediates = True

if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

if not return_intermediates:
return out, new_mems
Expand All @@ -2327,7 +2366,7 @@ def forward(
return out, intermediates

if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
return out, attn_maps

return out
Expand Down

0 comments on commit e6c488e

Please sign in to comment.