From f2b721aae7c33a34155a0db6135a559b75513334 Mon Sep 17 00:00:00 2001 From: tasdep Date: Fri, 21 Jun 2024 12:58:13 +0200 Subject: [PATCH 1/2] disable empirical normalizer updates on resume training --- rsl_rl/runners/on_policy_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 9e0a459..5f2458a 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -42,8 +42,12 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): self.save_interval = self.cfg["save_interval"] self.empirical_normalization = self.cfg["empirical_normalization"] if self.empirical_normalization: - self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device) - self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device) + if train_cfg.get("resume") == True: + until = 0 + else: + until = 1.0e8 + self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=until).to(self.device) + self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=until).to(self.device) else: self.obs_normalizer = torch.nn.Identity() # no normalization self.critic_obs_normalizer = torch.nn.Identity() # no normalization From 40e669de47a54067b2077481817f4b9296aa6b0e Mon Sep 17 00:00:00 2001 From: tasdep Date: Fri, 21 Jun 2024 15:18:34 +0200 Subject: [PATCH 2/2] Revert prev and register count in the state_dict so it is saved and loaded with the module --- rsl_rl/modules/normalizer.py | 2 +- rsl_rl/runners/on_policy_runner.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/rsl_rl/modules/normalizer.py b/rsl_rl/modules/normalizer.py index 771efcf..c80ad10 100644 --- a/rsl_rl/modules/normalizer.py +++ b/rsl_rl/modules/normalizer.py @@ -26,7 +26,7 @@ def __init__(self, shape, eps=1e-2, until=None): self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0)) self.register_buffer("_var", torch.ones(shape).unsqueeze(0)) self.register_buffer("_std", torch.ones(shape).unsqueeze(0)) - self.count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.long)) @property def mean(self): diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 5f2458a..9e0a459 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -42,12 +42,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): self.save_interval = self.cfg["save_interval"] self.empirical_normalization = self.cfg["empirical_normalization"] if self.empirical_normalization: - if train_cfg.get("resume") == True: - until = 0 - else: - until = 1.0e8 - self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=until).to(self.device) - self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=until).to(self.device) + self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device) + self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device) else: self.obs_normalizer = torch.nn.Identity() # no normalization self.critic_obs_normalizer = torch.nn.Identity() # no normalization