Commit
add compatibility for the residual VQ proposed in TIGER https://arxiv.org/abs/2305.05065, for building recommendation systems
lucidrains committed Nov 7, 2024
1 parent 35a8a41 commit a4bef4d
Showing 4 changed files with 93 additions and 13 deletions.
36 changes: 35 additions & 1 deletion README.md
@@ -100,6 +100,29 @@ quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
```


<a href="https://arxiv.org/abs/2305.05065">This paper</a> out of Google DeepMind claims that residual vector quantization can induce hierarchical semantic IDs for building a recommender system. In their scheme, they use an increasing number of codes across depth for it to work. This repository supports that scheme as follows:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
dim = 2,
codebook_size = (5, 128, 256), # from the topmost hierarchy to the lowest: 5 codes, then 128, then 256
)

x = torch.randn(2, 2, 2)

residual_vq.train()

quantized, indices, commit_loss = residual_vq(x, freeze_codebook = True)

quantized_out = residual_vq.get_output_from_indices(indices)

assert torch.allclose(quantized, quantized_out, atol = 1e-5)
```
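
To make the hierarchical semantic ID interpretation concrete, here is a minimal sketch of what the returned `indices` contain for the configuration above. The shape is an assumption based on the batch-first `(batch, seq, num_quantizers)` layout used elsewhere in this README, and the example ID values are made up:

```python
# x had shape (2, 2, 2) -> (batch, seq, dim)
# with three codebooks of sizes (5, 128, 256), indices stacks one code id per level
# along the last dimension, ordered from the coarsest codebook to the finest
print(indices.shape)  # expected: torch.Size([2, 2, 3])

# the per-position triple of ids, e.g. (3, 67, 201), plays the role of the
# hierarchical semantic id described in the TIGER paper
```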

## Initialization

The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag, `kmeans_init = True`, for either the `VectorQuantize` or `ResidualVQ` class.
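
A minimal sketch of turning this flag on for `ResidualVQ`; the `kmeans_init` flag is the one named above, and the other hyperparameters are arbitrary placeholders:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,            # arbitrary placeholder
    codebook_size = 256,  # arbitrary placeholder
    num_quantizers = 4,   # arbitrary placeholder
    kmeans_init = True    # initialize each codebook from the kmeans centroids of the first batch
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
```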
@@ -713,4 +736,15 @@ assert loss.item() >= 0
volume = {abs/2410.06424},
url = {https://api.semanticscholar.org/CorpusID:273229218}
}
```

```bibtex
@article{Rajput2023RecommenderSW,
title = {Recommender Systems with Generative Retrieval},
author = {Shashank Rajput and Nikhil Mehta and Anima Singh and Raghunandan H. Keshavan and Trung Hieu Vu and Lukasz Heldt and Lichan Hong and Yi Tay and Vinh Q. Tran and Jonah Samost and Maciej Kula and Ed H. Chi and Maheswaran Sathiamoorthy},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.05065},
url = {https://api.semanticscholar.org/CorpusID:258564854}
}
```
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.18.8"
version = "1.19.0"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
18 changes: 18 additions & 0 deletions tests/test_readme.py
@@ -201,6 +201,24 @@ def test_rq():
x = torch.randn(1, 1024, 512)
indices = quantizer(x)

def test_tiger():
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
dim = 2,
codebook_size = (5, 128, 256),
)

x = torch.randn(2, 2, 2)

residual_vq.train()

quantized, indices, commit_loss = residual_vq(x, freeze_codebook = True)

quantized_out = residual_vq.get_output_from_indices(indices) # pass your indices in here, but the indices must come from .eval() mode, as during training some of the indices are dropped out (-1)

assert torch.allclose(quantized, quantized_out, atol = 1e-5)

def test_fsq():
from vector_quantize_pytorch import FSQ

50 changes: 39 additions & 11 deletions vector_quantize_pytorch/residual_vq.py
@@ -1,5 +1,4 @@
from __future__ import annotations
from typing import List

import random
from math import ceil
@@ -28,6 +27,12 @@ def first(it):
def default(val, d):
return val if exists(val) else d

def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)

def unique(arr):
return list({*arr})

def round_up_multiple(num, mult):
return ceil(num / mult) * mult

@@ -110,7 +115,8 @@ def __init__(
self,
*,
dim,
num_quantizers,
num_quantizers: int | None = None,
codebook_size: int | tuple[int, ...],
codebook_dim = None,
shared_codebook = False,
heads = 1,
@@ -124,6 +130,8 @@
):
super().__init__()
assert heads == 1, 'residual vq is not compatible with multi-headed codes'
assert exists(num_quantizers) or isinstance(codebook_size, tuple)

codebook_dim = default(codebook_dim, dim)
codebook_input_dim = codebook_dim * heads

@@ -132,8 +140,6 @@
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.has_projections = requires_projection

self.num_quantizers = num_quantizers

self.accept_image_fmap = accept_image_fmap

self.implicit_neural_codebook = implicit_neural_codebook
@@ -150,7 +156,21 @@ def __init__(
manual_in_place_optimizer_update = True
)

self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
# take care of maybe different codebook sizes across depth, used in TIGER paper https://arxiv.org/abs/2305.05065

codebook_sizes = cast_tuple(codebook_size, num_quantizers)

num_quantizers = len(codebook_sizes)
self.num_quantizers = num_quantizers

assert len(codebook_sizes) == num_quantizers

self.codebook_sizes = codebook_sizes
self.uniform_codebook_size = len(unique(codebook_sizes)) == 1

# define vq across layers

self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_size = layer_codebook_size, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for layer_codebook_size in codebook_sizes])

assert all([not vq.has_projections for vq in self.layers])

@@ -167,6 +187,8 @@ def __init__(

if implicit_neural_codebook:
self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
else:
self.mlps = (None,) * (num_quantizers - 1)

# sharing codebook logic

@@ -175,6 +197,8 @@
if not shared_codebook:
return

assert self.uniform_codebook_size

first_vq, *rest_vq = self.layers
codebook = first_vq._codebook

@@ -192,8 +216,13 @@ def codebook_dim(self):
@property
def codebooks(self):
codebooks = [layer._codebook.embed for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
codebooks = rearrange(codebooks, 'q 1 c d -> q c d')

codebooks = tuple(rearrange(codebook, '1 ... -> ...') for codebook in codebooks)

if not self.uniform_codebook_size:
return codebooks

codebooks = torch.stack(codebooks)
return codebooks

def get_codes_from_indices(self, indices):
@@ -216,13 +245,12 @@ def get_codes_from_indices(self, indices):
mask = indices == -1.
indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

if not self.implicit_neural_codebook:
# gather all the codes
if not self.implicit_neural_codebook and self.uniform_codebook_size:

all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

else:
# else if using implicit neural codebook, codes will need to be derived layer by layer
# else if using implicit neural codebook, or non uniform codebook sizes, codes will need to be derived layer by layer

code_transform_mlps = (None, *self.mlps)

@@ -261,7 +289,7 @@ def forward(
self,
x,
mask = None,
indices: Tensor | List[Tensor] | None = None,
indices: Tensor | list[Tensor] | None = None,
return_all_codes = False,
sample_codebook_temp = None,
freeze_codebook = False,
