Merge branch 'main' into CML-log-uniform-sampler
christophmluscher committed Jan 13, 2025
2 parents 6726c8f + b3fa661 commit 09a22ce
Showing 3 changed files with 113 additions and 0 deletions.
2 changes: 2 additions & 0 deletions i6_models/parts/best_rq/__init__.py
@@ -0,0 +1,2 @@
from .mask import *
from .quantizer import *
74 changes: 74 additions & 0 deletions i6_models/parts/best_rq/mask.py
@@ -0,0 +1,74 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import numpy as np

__all__ = ["RandomMask"]


class RandomMask(nn.Module):
    """
    Randomly masks out consecutive frames along the time dimension; the masked frames can either be
    replaced with zeros or with learnable embeddings.
    Simplified version of the Fairseq compute_mask_indices function,
    cf. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
    """

    def __init__(
        self,
        input_dim: int,
        mask_replace_val: str,
        mask_percentage: float,
        mask_length: int,
    ):
        """
        :param input_dim: feature dimension of the input
        :param mask_replace_val: how to replace masked frames, either with zeros ("zero") or learnable embeddings ("learnable")
        :param mask_percentage: percentage of frames to be masked out
        :param mask_length: length of each mask span
        """
        super().__init__()

        assert mask_replace_val in ["learnable", "zero"], "not implemented yet"
        if mask_replace_val == "learnable":
            self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
        elif mask_replace_val == "zero":
            self.mask_emb = torch.zeros(input_dim)
        self.mask_percentage = mask_percentage
        self.mask_length = mask_length

    def forward(
        self,
        tensor: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ndim_batch, ndim_time, _ = tensor.size()

        mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)

        for i in range(ndim_batch):
            if padding_mask is not None:
                # effective sequence length without padded frames
                seq_len = ndim_time - padding_mask[i].long().sum().item()
                assert seq_len >= 0
            else:
                seq_len = ndim_time

            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_percentage * seq_len / float(self.mask_length)
                + np.random.rand()
            )

            min_len = self.mask_length
            if seq_len - min_len <= num_mask:
                min_len = seq_len - num_mask - 1
            # sample random start indices for the mask spans
            mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)

            for j in mask_idc:
                mask[i, j : j + self.mask_length] = True

        tensor[mask] = self.mask_emb.to(tensor.device)

        return tensor, mask.to(tensor.device)
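
For context, a minimal usage sketch of the new RandomMask module; the batch size, feature dimension, and hyperparameter values below are illustrative assumptions and not part of the commit:

    import torch
    from i6_models.parts.best_rq import RandomMask

    masker = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)
    features = torch.randn(4, 200, 80)  # [batch, time, feature], assumed shapes
    padding_mask = torch.zeros(4, 200, dtype=torch.bool)  # True would mark padded frames; here no padding
    masked, mask = masker(features, padding_mask)
    # masked: [4, 200, 80] with each masked span replaced by the learnable embedding
    # mask: [4, 200] boolean tensor marking the masked frames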
37 changes: 37 additions & 0 deletions i6_models/parts/best_rq/quantizer.py
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm

__all__ = [
    "RandomProjectionQuantizer",
]


class RandomProjectionQuantizer(nn.Module):
    """
    Implements the fixed random projection quantizer from BEST-RQ,
    cf. https://arxiv.org/pdf/2202.01855 for the theoretical background.
    Code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
    """

    def __init__(self, input_dim, codebook_dim, codebook_num_vars):
        """
        :param input_dim: feature dimension of the input
        :param codebook_dim: dimension of the codebook vectors
        :param codebook_num_vars: vocab size of the codebook
        """
        super().__init__()

        self.input_dim = input_dim

        # fixed projection matrix with Xavier initialization, registered as a non-trainable buffer
        P_init = torch.empty((input_dim, codebook_dim))
        self.register_buffer("P", nn.init.xavier_uniform_(P_init))

        # normalized random matrix used as the fixed codebook
        self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # project and normalize the input, then return the index of the nearest codebook vector
        x = F.normalize(x @ self.P)
        return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
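
For context, a minimal usage sketch of the new RandomProjectionQuantizer producing frame-level target indices; the feature dimension, codebook size, and shapes below are illustrative assumptions and not part of the commit:

    import torch
    from i6_models.parts.best_rq import RandomProjectionQuantizer

    quantizer = RandomProjectionQuantizer(input_dim=80, codebook_dim=16, codebook_num_vars=8192)
    features = torch.randn(4, 200, 80)  # [batch, time, feature], assumed shapes
    targets = quantizer(features)  # [4, 200] index of the nearest codebook entry per frame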
