From fbe20d4dc5e1d2498771a102bb88ace00811bb02 Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Fri, 10 Jan 2025 09:45:44 +0100
Subject: [PATCH] warning

---
 torchrl/objectives/ppo.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 5e5d4ea006e..3d1b3bd5088 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -47,6 +47,7 @@
     TDLambdaEstimator,
     VTrace,
 )
+import warnings
 
 
 class PPOLoss(LossModule):
@@ -613,6 +614,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings, this is highly recommended."
+                )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
@@ -881,6 +891,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings, this is highly recommended."
+                )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
@@ -1164,6 +1183,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         )
         advantage = tensordict_copy.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings, this is highly recommended."
+                )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
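
As context for the warning added above, here is a minimal sketch of the situation it targets: when the advantage tensor carries a trailing dimension beyond the tensordict batch (for instance an agent dimension), normalizing over every dimension pools statistics across agents, whereas excluding that dimension standardizes each agent independently. The shapes and the per-agent convention below are assumptions for illustration only; this is plain torch, not TorchRL's internal `_standardize` implementation.

import torch

# Hypothetical multi-agent rollout: tensordict batch of 64 transitions,
# with the advantage carrying an extra agent dimension of size 3.
advantage = torch.randn(64, 3)

# Normalizing over all dimensions (what happens when
# `normalize_advantage_exclude_dims` is left empty) mixes agents together.
pooled = (advantage - advantage.mean()) / advantage.std().clamp_min(1e-6)

# Excluding the agent dimension (what passing
# `normalize_advantage_exclude_dims=(-1,)` is meant to achieve) computes the
# statistics over the batch only, so each agent is standardized independently.
per_agent = (advantage - advantage.mean(dim=0, keepdim=True)) / advantage.std(
    dim=0, keepdim=True
).clamp_min(1e-6)

print(pooled.mean().item(), pooled.std().item())    # ~0 and ~1 globally
print(per_agent.mean(dim=0), per_agent.std(dim=0))  # ~0 and ~1 per agent

In practice this keeps a high-variance agent from dominating the normalization statistics of the others, which is why the warning singles out multi-agent and multi-objective setups.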