diff --git a/i6_models/parts/samplers/log_uniform.py b/i6_models/parts/samplers/log_uniform.py index 15feabe8..764ac1e4 100644 --- a/i6_models/parts/samplers/log_uniform.py +++ b/i6_models/parts/samplers/log_uniform.py @@ -37,7 +37,7 @@ def sample(self, num_samples: int) -> torch.Tensor: Returns a random tensor in the size of [num_samples]. :param num_samples: number of samples. - :return: + :return: [num_samples] """ return self._cat_sampler.sample(torch.Size([num_samples])) @@ -46,6 +46,6 @@ def log_prob(self, indices: torch.Tensor) -> torch.Tensor: Return log-probability of the given indices in the size of [B x T] :param indices: the ground truth target labels as indices. - :return: + :return: [B x T] """ return self._cat_sampler.log_prob(indices)