lucidrains committed Nov 6, 2024
1 parent 46bd7af commit f97a37b
Showing 3 changed files with 43 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.5"
+version = "1.18.6"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
23 changes: 21 additions & 2 deletions vector_quantize_pytorch/residual_fsq.py
@@ -1,6 +1,6 @@
 import random
 from math import log2
-from functools import partial
+from functools import partial, cache
 
 from typing import List
 
@@ -9,6 +9,7 @@
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
+import torch.distributed as dist
 
 from vector_quantize_pytorch.finite_scalar_quantization import FSQ
 
@@ -30,6 +31,12 @@ def default(val, d):
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult
 
+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 # main class
 
 class ResidualFSQ(Module):
@@ -167,7 +174,19 @@ def forward(
         # also prepare null indices
 
         if should_quantize_dropout:
-            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
+
+            if exists(rand_quantize_dropout_fixed_seed):
+                # seed is manually passed in
+                rand = random.Random(rand_quantize_dropout_fixed_seed)
+
+            elif is_distributed():
+                # in distributed environment, synchronize a random seed value if not given
+                t = torch.tensor(random.randrange(10_000), device = device)
+                dist.all_reduce(t)
+                rand = random.Random(t.item())
+
+            else:
+                rand = random
 
             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
 
23 changes: 21 additions & 2 deletions vector_quantize_pytorch/residual_lfq.py
@@ -1,12 +1,13 @@
 import random
 from math import log2
-from functools import partial
+from functools import partial, cache
 
 import torch
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
+import torch.distributed as dist
 
 from vector_quantize_pytorch.lookup_free_quantization import LFQ
 
@@ -25,6 +26,12 @@ def default(val, d):
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult
 
+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 # main class
 
 class ResidualLFQ(Module):
@@ -144,7 +151,19 @@ def forward(
         # also prepare null indices and loss
 
         if should_quantize_dropout:
-            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
+
+            if exists(rand_quantize_dropout_fixed_seed):
+                # seed is manually passed in
+                rand = random.Random(rand_quantize_dropout_fixed_seed)
+
+            elif is_distributed():
+                # in distributed environment, synchronize a random seed value if not given
+                t = torch.tensor(random.randrange(10_000), device = device)
+                dist.all_reduce(t)
+                rand = random.Random(t.item())
+
+            else:
+                rand = random
 
             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
 

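For reference, the seed-synchronization pattern introduced in both files above can be read on its own in the short sketch below. It is a minimal illustration under stated assumptions, not library code: it presumes a torch.distributed process group is already initialized (e.g. launched with torchrun), and the helper name synced_dropout_seed is hypothetical.

import random

import torch
import torch.distributed as dist

def synced_dropout_seed(device):
    # each rank draws its own random value (hypothetical helper, for illustration only)
    t = torch.tensor(random.randrange(10_000), device = device)
    # all_reduce sums the per-rank values in place, so every rank ends up
    # holding the same number and therefore samples the same dropout index
    dist.all_reduce(t)
    return t.item()

# usage sketch inside a training step:
# rand = random.Random(synced_dropout_seed(torch.device('cuda')))
# dropout_index = rand.randrange(quantize_dropout_cutoff_index, num_quantizers)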