Skip to content

Commit

Permalink
warning
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jan 10, 2025
1 parent 9cfe213 commit fbe20d4
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TDLambdaEstimator,
VTrace,
)
import warnings


class PPOLoss(LossModule):
Expand Down Expand Up @@ -613,6 +614,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
Expand Down Expand Up @@ -881,6 +891,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
Expand Down Expand Up @@ -1164,6 +1183,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
)
advantage = tensordict_copy.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
Expand Down

0 comments on commit fbe20d4

Please sign in to comment.