From 0d8a9e398614c8515fb2af1d78b84537e9903048 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?=
Date: Thu, 19 Dec 2024 11:23:59 +0100
Subject: [PATCH] improve code

* add device param
* all in torch

Co-authored-by: Albert Zeyer
---
 i6_models/samplers/log_uniform.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/i6_models/samplers/log_uniform.py b/i6_models/samplers/log_uniform.py
index 7973ab59..462fa20a 100644
--- a/i6_models/samplers/log_uniform.py
+++ b/i6_models/samplers/log_uniform.py
@@ -8,20 +8,19 @@
 
 
 class LogUniformSampler(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes: int, *, device: Optional[torch.device] = None):
         super().__init__()
 
         # assumes count-sorted vocabulary, descending
         self.num_classes = num_classes
 
         # approximately zipf distribution
-        self._distribution = [
-            (math.log1p(w + 1) - math.log1p(w)) / math.log1p(self.num_classes) for w in range(self.num_classes)
-        ]
-        self._distribution = torch.tensor(self._distribution).clamp(min=1e-10)
+        ws = torch.arange(self.num_classes, dtype=torch.get_default_dtype(), device=device)
+        self._distribution = (torch.log1p(ws + 1) - torch.log1p(ws)) / torch.log1p(torch.tensor(self.num_classes))
+        self._distribution.clamp_(min=1e-10)
         self._distribution /= self._distribution.sum()
 
-        self._cat_sampler = torch.distributions.categorical.Categorical(probs=self._distribution.cuda())
+        self._cat_sampler = torch.distributions.categorical.Categorical(probs=self._distribution)
 
     def sample(self, num_samples):
         return self._cat_sampler.sample(torch.Size([num_samples]))
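
Usage note (not part of the patch): a minimal sketch of how the changed constructor could be called after this patch, assuming the i6_models package is importable and that Optional is imported in the module. The vocabulary size of 10000 and the sample count of 64 are illustrative values only.

    import torch

    from i6_models.samplers.log_uniform import LogUniformSampler

    # The new keyword-only `device` argument replaces the previous hard-coded
    # `.cuda()` call, so CPU-only setups can construct the sampler as well.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # num_classes is the vocabulary size; the class assumes a count-sorted
    # (descending) vocabulary, as the module comment states.
    sampler = LogUniformSampler(num_classes=10000, device=device)

    # Draw 64 class indices from the approximately Zipfian distribution.
    samples = sampler.sample(64)
    print(samples.shape)  # torch.Size([64]), on the chosen device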