cleanup rotation trick
lucidrains committed Oct 11, 2024
1 parent 089011b commit 56f20dc
Showing 3 changed files with 43 additions and 29 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.0"
+version = "1.18.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
7 changes: 5 additions & 2 deletions tests/test_readme.py
@@ -5,8 +5,10 @@ def exists(v):
     return v is not None

 @pytest.mark.parametrize('use_cosine_sim', (True, False))
+@pytest.mark.parametrize('rotation_trick', (True, False))
 def test_vq(
-    use_cosine_sim
+    use_cosine_sim,
+    rotation_trick
 ):
     from vector_quantize_pytorch import VectorQuantize

@@ -15,7 +17,8 @@ def test_vq(
         codebook_size = 512, # codebook size
         decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster
         commitment_weight = 1., # the weight on the commitment loss
-        use_cosine_sim = use_cosine_sim
+        use_cosine_sim = use_cosine_sim,
+        rotation_trick = rotation_trick
     )

     x = torch.randn(1, 1024, 256)
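For reference, the usage this test exercises looks roughly like the README example below. This is a sketch, not part of the commit; the VectorQuantize(... / dim = 256 lines sit outside the hunk shown above, and the new rotation_trick flag is the constructor argument added by this commit (it defaults to True).

import torch
from vector_quantize_pytorch import VectorQuantize

# sketch following the repository README; values are illustrative
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,      # codebook size
    decay = 0.8,              # exponential moving average decay for the codebook
    commitment_weight = 1.,   # weight on the commitment loss
    rotation_trick = True     # gradients via the rotation trick instead of plain straight-through
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # (1, 1024, 256), (1, 1024), (1,)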
63 changes: 37 additions & 26 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -28,8 +28,11 @@ def noop(*args, **kwargs):
 def identity(t):
     return t

-def l2norm(t):
-    return F.normalize(t, p = 2, dim = -1)
+def l2norm(t, dim = -1, eps = 1e-6):
+    return F.normalize(t, p = 2, dim = dim, eps = eps)
+
+def safe_div(num, den, eps = 1e-6):
+    return num / den.clamp(min = eps)

 def Sequential(*modules):
     modules = [*filter(exists, modules)]
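Both helpers guard against vectors with near-zero norm: the previous code added an epsilon to the denominator, whereas safe_div clamps the denominator from below. A tiny illustration, not part of the commit:

import torch

num = torch.tensor([1., 2.])
den = torch.tensor([0., 4.])

# clamping keeps the result finite (roughly [1e6, 0.5]) instead of producing inf
print(num / den.clamp(min = 1e-6))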
@@ -73,6 +76,19 @@ def lens_to_mask(lens, max_length):
     seq = torch.arange(max_length, device = lens.device)
     return seq < lens[:, None]

+def efficient_rotation_trick_transform(u, q, e):
+    """
+    4.2 in https://arxiv.org/abs/2410.06424
+    """
+    e = rearrange(e, 'b d -> b 1 d')
+    w = l2norm(u + q, dim = 1).detach()
+
+    return (
+        e -
+        2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
+        2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
+    )
+
 def uniform_init(*shape):
     t = torch.empty(shape)
     nn.init.kaiming_uniform_(t)
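The new helper is the batched rotation from section 4.2 of the paper: with w = (u + q) / ||u + q||, it computes e - 2(e·w)w + 2(e·u)q per row, which carries a unit vector u exactly onto the unit vector q; w, u and q are detached, so gradients flow only through e while the rotation itself is treated as a constant. A quick sanity-check sketch, not part of the commit, assuming the helper can be imported from the module file directly:

import torch
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import efficient_rotation_trick_transform

u = F.normalize(torch.randn(4, 32), dim = -1)   # unit-norm encoder directions
q = F.normalize(torch.randn(4, 32), dim = -1)   # unit-norm codebook directions

# with e = u, the detached rotation should land exactly on q
out = efficient_rotation_trick_transform(u, q, u).squeeze(1)
assert torch.allclose(out, q, atol = 1e-5)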
@@ -811,7 +827,7 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
-        rotation_trick = True, # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424.
+        rotation_trick = True, # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -946,13 +962,6 @@ def codebook(self, codes):

         self._codebook.embed.copy_(codes)

-    @staticmethod
-    def rotation_trick_transform(u, q, e):
-        w = ((u + q) / torch.norm(u + q, dim=1, keepdim=True)).detach()
-        e = e - 2 * torch.bmm(torch.bmm(e, w.unsqueeze(-1)), w.unsqueeze(1)) + 2 * torch.bmm(
-            torch.bmm(e, u.unsqueeze(-1).detach()), q.unsqueeze(1).detach())
-        return e
-
     def get_codes_from_indices(self, indices):
         codebook = self.codebook
         is_multiheaded = codebook.ndim > 2
@@ -1103,23 +1112,25 @@ def forward(

         commit_quantize = maybe_detach(quantize)

-        # Use the rotation trick (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
         if self.rotation_trick:
-            init_shape = x.shape
-            x = x.reshape(-1, init_shape[-1])
-            quantize = quantize.reshape(-1, init_shape[-1])
-
-            eps = 1e-6 # For numerical stability if any vector is close to 0 norm.
-            rot_quantize = self.rotation_trick_transform(
-                x / (torch.norm(x, dim=1, keepdim=True) + eps),
-                quantize / (torch.norm(quantize, dim=1, keepdim=True) + eps),
-                x.unsqueeze(1)).squeeze()
-            quantize = rot_quantize * (torch.norm(quantize, dim=1, keepdim=True)
-                / (torch.norm(x, dim=1, keepdim=True) + 1e-6)).detach()
-
-            x = x.reshape(init_shape)
-            quantize = quantize.reshape(init_shape)
-        else: # Use STE to get gradients through VQ layer.
+            # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
+            x, inverse = pack_one(x, '* d')
+            quantize, _ = pack_one(quantize, '* d')
+
+            norm_x = x.norm(dim = -1, keepdim = True)
+            norm_quantize = quantize.norm(dim = -1, keepdim = True)
+
+            rot_quantize = efficient_rotation_trick_transform(
+                safe_div(x, norm_x),
+                safe_div(quantize, norm_quantize),
+                x
+            ).squeeze()
+
+            quantize = rot_quantize * safe_div(norm_quantize, norm_x).detach()
+
+            x, quantize = inverse(x), inverse(quantize)
+        else:
+            # standard STE to get gradients through VQ layer.
             quantize = x + (quantize - x).detach()

         if self.sync_update_v > 0.:
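In words, the rewritten branch flattens x and its selected codes to (n, d) with pack_one, rotates each x direction onto its code direction via the detached transform above, then rescales by the detached norm ratio, so the forward value still equals the quantized vectors while the gradient reaching x is rotated and rescaled rather than copied as in the straight-through estimator. A rough equivalence sketch, not part of the commit and under the same import assumption as the earlier sketch:

import torch
from vector_quantize_pytorch.vector_quantize_pytorch import efficient_rotation_trick_transform

x = torch.randn(8, 32, requires_grad = True)   # flattened encoder outputs (n, d)
quantize = torch.randn(8, 32)                  # their selected codebook vectors (n, d)

norm_x = x.norm(dim = -1, keepdim = True)
norm_q = quantize.norm(dim = -1, keepdim = True)

rotated = efficient_rotation_trick_transform(x / norm_x, quantize / norm_q, x).squeeze(1)
out = rotated * (norm_q / norm_x).detach()

# same forward value as the straight-through branch ...
assert torch.allclose(out, x + (quantize - x).detach(), atol = 1e-4)
# ... but the backward pass reaches x through the (detached) rotation and rescaling
out.sum().backward()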
