diff --git a/i6_models/samplers/log_uniform.py b/i6_models/samplers/log_uniform.py index 7973ab59..462fa20a 100644 --- a/i6_models/samplers/log_uniform.py +++ b/i6_models/samplers/log_uniform.py @@ -8,20 +8,19 @@ class LogUniformSampler(nn.Module): - def __init__(self, num_classes): + def __init__(self, num_classes: int, *, device: Optional[torch.device] = None): super().__init__() # assumes count-sorted vocabulary, descending self.num_classes = num_classes # approximately zipf distribution - self._distribution = [ - (math.log1p(w + 1) - math.log1p(w)) / math.log1p(self.num_classes) for w in range(self.num_classes) - ] - self._distribution = torch.tensor(self._distribution).clamp(min=1e-10) + ws = torch.arange(self.num_classes, dtype=torch.get_default_dtype(), device=device) + self._distribution = (torch.log1p(ws + 1) - torch.log1p(ws)) / torch.log1p(torch.tensor(self.num_classes)) + self._distribution.clamp_(min=1e-10) self._distribution /= self._distribution.sum() - self._cat_sampler = torch.distributions.categorical.Categorical(probs=self._distribution.cuda()) + self._cat_sampler = torch.distributions.categorical.Categorical(probs=self._distribution) def sample(self, num_samples): return self._cat_sampler.sample(torch.Size([num_samples]))