From 1dca119eac6288078b9dfa1c8353592f8f3560b8 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 9 Jan 2025 17:25:48 +0000
Subject: [PATCH] init

---
 torchrl/objectives/cql.py                  | 2 +-
 torchrl/objectives/crossq.py               | 2 +-
 torchrl/objectives/decision_transformer.py | 2 +-
 torchrl/objectives/sac.py                  | 4 ++--
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 375e3834dfc..6e056589a8c 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         alpha = self.log_alpha.data.exp()
         return alpha
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 801180901a7..22e84673641 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index 013e28713bf..a0d193acbfc 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs):
 
     @property
     def alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index dafff17011e..eae6b7feb34 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
@@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data = self.log_alpha.data.clamp(
                 self.min_log_alpha, self.max_log_alpha
            )
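
Context for the change above: `Tensor.clamp_` accepts `None` for either bound but raises if both are `None`, and the old guard (`if self.min_log_alpha is not None:`) skipped the clamp entirely whenever only `max_log_alpha` was configured, so the upper bound was silently ignored. Below is a minimal standalone sketch of the corrected guard; `log_alpha`, `min_log_alpha`, and `max_log_alpha` are free variables standing in for the attributes of the patched loss modules, not the modules themselves.

    import torch

    # Standalone mirror of the `_alpha` property pattern patched above.
    log_alpha = torch.tensor(3.0)
    min_log_alpha = None  # no lower bound configured
    max_log_alpha = 2.0   # only an upper bound configured

    # Old guard: `if min_log_alpha is not None:` never fires here, so alpha
    # would come out as exp(3.0) despite the configured upper bound.
    # New guard: clamp whenever at least one bound is set. `Tensor.clamp_`
    # tolerates `None` for a single bound but raises if both are `None`,
    # which the guard also rules out.
    if min_log_alpha is not None or max_log_alpha is not None:
        log_alpha.data.clamp_(min_log_alpha, max_log_alpha)

    with torch.no_grad():
        alpha = log_alpha.exp()
    print(alpha)  # tensor(7.3891) == exp(2.0): the upper bound is enforced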