diff --git a/README.md b/README.md
index e5c21d4b..6cbc30e3 100644
--- a/README.md
+++ b/README.md
@@ -2221,4 +2221,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```
+@article{Yang2017BreakingTS,
+    title   = {Breaking the Softmax Bottleneck: A High-Rank RNN Language Model},
+    author  = {Zhilin Yang and Zihang Dai and Ruslan Salakhutdinov and William W. Cohen},
+    journal = {ArXiv},
+    year    = {2017},
+    volume  = {abs/1711.03953},
+    url     = {https://api.semanticscholar.org/CorpusID:26238954}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
diff --git a/setup.py b/setup.py
index 779e71bd..a31fcb2b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.36.0',
+  version = '1.37.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index e51d07c3..695f1013 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -259,3 +259,23 @@ def test_recycling():
     model.eval()

     eval_logits = model(x, recycle_steps = 3)
+
+def test_mos():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        mixture_of_softmax = True,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    logits = model(x)
+
+    model.eval()
+
+    eval_logits = model(x)
diff --git a/x_transformers/autoregressive_wrapper.py b/x_transformers/autoregressive_wrapper.py
index 74e3f0f3..e8420182 100644
--- a/x_transformers/autoregressive_wrapper.py
+++ b/x_transformers/autoregressive_wrapper.py
@@ -317,7 +317,9 @@ def forward(self, x, return_outputs = False, **kwargs):
             **kwargs
         )

-        loss = F.cross_entropy(
+        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+
+        loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
             ignore_index = ignore_index
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 9b726af0..d95c3b66 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable

 import math
 from random import random, randrange
@@ -14,7 +15,6 @@
 from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -28,14 +28,16 @@

 @dataclass
 class LayerIntermediates:
-    hiddens:            List[Tensor] | None = None   # all hiddens, before the final norm (in pre-norm architecture)
+    hiddens:            list[Tensor] | None = None   # all hiddens, before the final norm (in pre-norm architecture)
     last_hidden:        Tensor | None = None         # very last hidden after all attention layers, after the final norm
-    attn_intermediates: List[Intermediates] | None = None
-    layer_hiddens:      List[Tensor] | None = None
+    attn_intermediates: list[Intermediates] | None = None
+    layer_hiddens:      list[Tensor] | None = None
     attn_z_loss:        Tensor | None = None
     mems:               Tensor | None = None
     memory_tokens:      Tensor | None = None

+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers

 def exists(val):
@@ -92,6 +94,9 @@ def Sequential(*modules):

 # tensor helpers

+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max

@@ -114,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
     den = mask.sum(dim = dim).clamp(min = 1.)
     return num / den

-def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t

@@ -131,7 +136,7 @@ def or_reduce(masks):

 # auxiliary loss helpers

 def calc_z_loss(
-    pre_softmax_attns: List[Tensor],
+    pre_softmax_attns: list[Tensor],
     mask = None,
     weight = 1.
 ):
@@ -611,7 +616,7 @@ def __init__(
         dim_condition = default(dim_condition, dim)

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -666,7 +671,7 @@ def __init__(
         self.scale = dim ** 0.5

         dim_condition = default(dim_condition, dim)
-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -749,7 +754,7 @@ def forward(self, x, **kwargs):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
@@ -817,7 +822,7 @@ class ConcatCombine(Module):
     def __init__(self, dim, prev_layer_ind):
         super().__init__()
         self.prev_layer_ind = prev_layer_ind
-        self.combine = nn.Linear(dim * 2, dim, bias = False)
+        self.combine = LinearNoBias(dim * 2, dim)

     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
@@ -957,17 +962,17 @@ def __init__(
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads

-        self.to_q = nn.Linear(dim, q_dim, bias = False)
-        self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
+        self.to_q = LinearNoBias(dim, q_dim)
+        self.to_k = LinearNoBias(dim_kv, k_dim)

         # shared key / values, for further memory savings during inference

         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
+        self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

         # relations projection from tp-attention

-        self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
+        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

         # add GLU gating for aggregated values, from alphafold2
@@ -1063,7 +1068,7 @@ def __init__(
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)
+        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

         # whether to rotate positions into values, for absolute positions in addition to relative
@@ -1109,7 +1114,7 @@ def forward(

         q = rearrange(q, 'b n (h d) -> b h n d', h = h)

-        k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
+        k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))

         if exists(cache):
             ck, cv = cache.cached_kv
@@ -1164,12 +1169,12 @@ def forward(

         # i, j determined for relative positional bias, excluding memory key / values

-        i, j = map(lambda t: t.shape[-2], (q, k))
+        i, j = tuple(t.shape[-2] for t in (q, k))

         # maybe append memory key / values

         if num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

             if self.qk_norm:
                 mem_k = l2norm(mem_k)
@@ -1302,8 +1307,8 @@ def __init__(
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str, ...] | None = None,
-        layers_execute_order: Tuple[int, ...] | None = None,
+        custom_layers: tuple[str, ...] | None = None,
+        layers_execute_order: tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1464,7 +1469,7 @@ def __init__(

         if self.need_condition and adaptive_condition_mlp:
             self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                 nn.SiLU()
             )
@@ -1635,7 +1640,7 @@ def forward(
         return_hiddens = False,
         rotary_pos_emb = None,
         condition = None,
-        layers_execute_order: Tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1973,7 +1978,7 @@ def __init__(
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
-        embed_num_tokens: Dict[str, int] = dict(),
+        embed_num_tokens: dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1996,6 +2001,8 @@ def __init__(
         use_cls_token = False,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
+        mixture_of_softmax = False,
+        mixture_of_softmax_k = 4,
     ):
         super().__init__()
@@ -2050,7 +2057,7 @@ def __init__(
         # maybe recycling

         self.recycling = recycling
-        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+        self.recycled_proj = LinearNoBias(dim, dim) if recycling else None

         self.train_max_recycle_steps = train_max_recycle_steps
@@ -2066,21 +2073,37 @@ def __init__(

         self.average_pool_embed = average_pool_embed

+        # output type
+
+        self.is_log_prob = mixture_of_softmax
+
+        self.to_mixture = None
+        self.combine_mixture = None
+
+        if mixture_of_softmax:
+            assert num_output_heads == 1
+
+            self.to_mixture = Sequential(
+                LinearNoBias(dim, dim * mixture_of_softmax_k),
+                Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
+            )
+
+            self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
+
         # output head, usually to logits of num_tokens

         logits_dim = default(logits_dim, num_tokens)

-        self.has_multiple_heads = False
+        self.has_multiple_heads = num_output_heads > 1

         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = LinearNoBias(dim, logits_dim)

         # memory tokens (like [cls]) from Memory Transformers paper
@@ -2124,7 +2147,7 @@ def forward(
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids: Dict[str, Tensor] = dict(),
+        embed_ids: dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -2281,6 +2304,14 @@ def forward(
         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')

+        # handle expansion to mixture if needed (for mixture of softmax)
+
+        combine_mixture = None
+
+        if exists(self.to_mixture):
+            combine_mixture = self.combine_mixture(x).softmax(dim = -1)
+            x = self.to_mixture(x)
+
         # projecting to logits

         if not return_embeddings:
@@ -2289,6 +2320,14 @@ def forward(
             else:
                 logits = self.to_logits(x)

+        # handle maybe combine mixture
+
+        if exists(combine_mixture):
+            with autocast('cuda', enabled = False):
+                prob = logits.softmax(dim = -1)
+                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
+                logits = log(mos)
+
         # maybe squeeze out last dimension of logits

         if self.squeeze_out_last_dim:
@@ -2309,14 +2348,14 @@ def forward(
         # aux loss

         if return_attn_z_loss:
-            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

             if not return_intermediates:
                 return out, new_mems
@@ -2327,7 +2366,7 @@ def forward(
             return out, intermediates

         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             return out, attn_maps

         return out
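
For context on the feature this diff adds: `mixture_of_softmax` implements the mixture-of-softmaxes output head from the Yang et al. citation above. The final hidden state is expanded into `mixture_of_softmax_k` components, each component gets its own softmax over the vocabulary, and the per-component probabilities are blended with learned mixture weights, so `TransformerWrapper` now returns log-probabilities (`is_log_prob = True`) and `AutoregressiveWrapper` trains with `F.nll_loss` instead of `F.cross_entropy`. A minimal usage sketch mirroring the new `test_mos` test (the `AutoregressiveWrapper` pairing is illustrative and not part of this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# mixture of softmax makes the wrapper output log-probabilities instead of raw logits
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    mixture_of_softmax = True,     # new flag added in this diff
    mixture_of_softmax_k = 4,      # number of softmax components (the default)
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (2, 1024))

log_probs = model(x)   # (2, 1024, 20000), already log-normalized over the vocabulary

# the autoregressive wrapper checks net.is_log_prob and switches to nll_loss accordingly
ar_model = AutoregressiveWrapper(model)
loss = ar_model(x)
loss.backward()
```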