diff --git a/i6_models/parts/best_rq/__init__.py b/i6_models/parts/best_rq/__init__.py
new file mode 100644
index 00000000..fb9a81db
--- /dev/null
+++ b/i6_models/parts/best_rq/__init__.py
@@ -0,0 +1,2 @@
+from .mask import *
+from .quantizer import *
diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
new file mode 100644
index 00000000..2e88479b
--- /dev/null
+++ b/i6_models/parts/best_rq/mask.py
@@ -0,0 +1,74 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+__all__ = ["RandomMask"]
+
+
+class RandomMask(nn.Module):
+    """
+    Randomly masks out spans of consecutive frames along the time dimension; the masked frames are
+    replaced either with zeros or with a learnable embedding.
+    Simplified version of the Fairseq compute_mask_indices function,
+    c.f. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        mask_replace_val: str,
+        mask_percentage: float,
+        mask_length: int,
+    ):
+        """
+        :param input_dim: feature dimension of the input
+        :param mask_replace_val: how to replace masked frames, either with zeros ("zero") or a learnable embedding ("learnable")
+        :param mask_percentage: percentage of frames to be masked out
+        :param mask_length: length of each mask span
+        """
+        super().__init__()
+
+        assert mask_replace_val in ["learnable", "zero"], "not implemented yet"
+        if mask_replace_val == "learnable":
+            self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
+        elif mask_replace_val == "zero":
+            self.mask_emb = torch.zeros(input_dim)
+        self.mask_percentage = mask_percentage
+        self.mask_length = mask_length
+
+    def forward(
+        self,
+        tensor: torch.Tensor,
+        padding_mask: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ndim_batch, ndim_time, _ = tensor.size()
+
+        mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool, device=tensor.device)
+
+        # sample mask span start indices independently for each sequence in the batch
+        for i in range(ndim_batch):
+            if padding_mask is not None:
+                seq_len = ndim_time - padding_mask[i].long().sum().item()
+                assert seq_len >= 0
+            else:
+                seq_len = ndim_time
+
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                self.mask_percentage * seq_len / float(self.mask_length)
+                + np.random.rand()
+            )
+
+            min_len = self.mask_length
+            if seq_len - min_len <= num_mask:
+                min_len = seq_len - num_mask - 1
+            mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
+
+            for j in mask_idc:
+                mask[i, j : j + self.mask_length] = True
+
+        tensor[mask] = self.mask_emb.to(tensor.device)
+
+        return tensor, mask
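For reference, a minimal usage sketch of the RandomMask part above (not part of the diff; tensor shapes, hyperparameter values, and the padding convention shown are illustrative assumptions). The module masks roughly mask_percentage of the non-padded frames of a [batch, time, feature] tensor in spans of mask_length frames and returns the (in-place) masked tensor together with the boolean frame mask:

import torch
from i6_models.parts.best_rq import RandomMask

features = torch.randn(4, 100, 80)  # [batch, time, feature_dim], illustrative shapes
padding_mask = torch.zeros(4, 100, dtype=torch.bool)  # True marks padded frames
padding_mask[0, 90:] = True  # e.g. the first sequence is 10 frames shorter

masking = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)
masked_features, mask = masking(features, padding_mask)
# masked_features: [4, 100, 80] with masked spans replaced by the learnable embedding
# mask: [4, 100] bool, True at the frames that were replaced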
diff --git a/i6_models/parts/best_rq/quantizer.py b/i6_models/parts/best_rq/quantizer.py
new file mode 100644
index 00000000..8633eb42
--- /dev/null
+++ b/i6_models/parts/best_rq/quantizer.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.linalg import vector_norm
+
+__all__ = [
+    "RandomProjectionQuantizer",
+]
+
+
+class RandomProjectionQuantizer(nn.Module):
+    """
+    Implements the fixed random projection quantizer from BEST-RQ,
+    c.f. https://arxiv.org/pdf/2202.01855 for the theoretical background.
+    Code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
+    """
+
+    def __init__(self, input_dim: int, codebook_dim: int, codebook_num_vars: int):
+        """
+        :param input_dim: feature dimension of the input
+        :param codebook_dim: dimension of each codebook vector
+        :param codebook_num_vars: number of codebook entries (vocab size)
+        """
+        super().__init__()
+
+        self.input_dim = input_dim
+
+        # fixed projection matrix with Xavier initialization, registered as a (non-trainable) buffer
+        P_init = torch.empty((input_dim, codebook_dim))
+        self.register_buffer("P", nn.init.xavier_uniform_(P_init))
+
+        # fixed random codebook with l2-normalized entries
+        self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.normalize(x @ self.P)  # [B, T, codebook_dim], projected and l2-normalized
+        return vector_norm(self.CB.unsqueeze(1) - x.unsqueeze(1), dim=-1).argmin(dim=1)  # [B, T] codebook indices
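And a rough sketch of how the two parts could be combined into BEST-RQ-style training targets (the encoder and loss are only indicated in comments and are not part of this diff; all shapes and hyperparameter values are illustrative assumptions): the quantizer maps the clean, unmasked features to per-frame codebook indices, which then serve as classification targets at the masked positions.

import torch
from i6_models.parts.best_rq import RandomMask, RandomProjectionQuantizer

features = torch.randn(4, 100, 80)  # [batch, time, feature_dim], illustrative shapes

# targets are computed on the clean features before masking
quantizer = RandomProjectionQuantizer(input_dim=80, codebook_dim=16, codebook_num_vars=1024)
targets = quantizer(features)  # [4, 100] codebook indices

# clone so the in-place masking does not affect the features the targets were computed from
masking = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)
masked_features, mask = masking(features.clone(), padding_mask=None)

# a hypothetical encoder would map masked_features to logits over the codebook vocabulary;
# the BEST-RQ loss is then a cross entropy restricted to the masked frames, e.g.
# loss = torch.nn.functional.cross_entropy(logits[mask], targets[mask])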