diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index c48d8168887..d2ffba30686 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -389,6 +389,17 @@ def sample( ) -> torch.Tensor: ... + @property + def deterministic_sample(self): + return self.mode + + @property + def mode(self) -> torch.Tensor: + if hasattr(self, "logits"): + return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) + else: + return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1))