diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cec72857..d05cb154 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a6 (WIP) +Release 2.4.0a8 (WIP) -------------------------- Breaking Changes: @@ -19,6 +19,7 @@ Bug Fixes: ^^^^^^^^^^ - Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) - Updated QR-DQN paper link in docs (@corentinlger) +- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False) Deprecations: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 77dfde38..1a0d53aa 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -304,7 +304,7 @@ def predict( with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks) # Convert to numpy - actions = actions.cpu().numpy() + actions = actions.cpu().numpy() # type: ignore[assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 7cd97cf3..5b976939 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union import numpy as np import torch as th @@ -455,3 +455,6 @@ def learn( reset_num_timesteps=reset_num_timesteps, progress_bar=progress_bar, ) + + def _excluded_save_params(self) -> List[str]: + return super()._excluded_save_params() + ["_last_lstm_states"] # noqa: RUF005 diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 464a5c4d..433f7df2 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.4.0a6 +2.4.0a8 \ No newline at end of file