diff --git a/rsl_rl/modules/normalizer.py b/rsl_rl/modules/normalizer.py index 771efcf..c80ad10 100644 --- a/rsl_rl/modules/normalizer.py +++ b/rsl_rl/modules/normalizer.py @@ -26,7 +26,7 @@ def __init__(self, shape, eps=1e-2, until=None): self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0)) self.register_buffer("_var", torch.ones(shape).unsqueeze(0)) self.register_buffer("_std", torch.ones(shape).unsqueeze(0)) - self.count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.long)) @property def mean(self): diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 5f2458a..b475620 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -42,12 +42,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): self.save_interval = self.cfg["save_interval"] self.empirical_normalization = self.cfg["empirical_normalization"] if self.empirical_normalization: - if train_cfg.get("resume") == True: - until = 0 - else: - until = 1.0e8 - self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=until).to(self.device) - self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=until).to(self.device) + self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device) + self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device) else: self.obs_normalizer = torch.nn.Identity() # no normalization self.critic_obs_normalizer = torch.nn.Identity() # no normalization @@ -264,6 +260,7 @@ def save(self, path, infos=None): if self.empirical_normalization: saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict() saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict() + torch.save(saved_dict, path) # Upload model to external logging service