From 42595a5de4e6e27d2b3f83c143e3360be03e36ef Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 13 Aug 2024 16:01:29 -0400 Subject: [PATCH] Fix warning when loading a `RecurrentPPO` model (#255) * Reformat configs * Fix warning when loading RecurrentPPO agent --- .github/workflows/ci.yml | 76 +++++++++++----------- docs/misc/changelog.rst | 3 +- pyproject.toml | 16 ++--- sb3_contrib/common/maskable/policies.py | 2 +- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 5 +- sb3_contrib/version.txt | 2 +- 6 files changed, 54 insertions(+), 50 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6f81953f..ea1a3c84 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: @@ -22,42 +22,42 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # Install Atari Roms + pip install autorom + wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz - # Install master version - # and dependencies for docs and tests - pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3" - pip install . - # Use headless version - pip install opencv-python-headless + # Install master version + # and dependencies for docs and tests + pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3" + pip install . + # Use headless version + pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - - name: Check codestyle - run: | - make check-codestyle - - name: Build the doc - run: | - make doc - - name: Type check - run: | - make type - - name: Test with pytest - run: | - make pytest + - name: Lint with ruff + run: | + make lint + - name: Check codestyle + run: | + make check-codestyle + - name: Build the doc + run: | + make doc + - name: Type check + run: | + make type + - name: Test with pytest + run: | + make pytest diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 63776e6e..d63060c7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 2.4.0a4 (WIP) +Release 2.4.0a8 (WIP) -------------------------- Breaking Changes: @@ -18,6 +18,7 @@ Bug Fixes: ^^^^^^^^^^ - Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) - Updated QR-DQN paper link in docs (@corentinlger) +- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False) Deprecations: ^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 7aaad514..50f3dfce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ ignore = ["B028", "RUF013"] [tool.ruff.lint.per-file-ignores] # ClassVar, implicit optional check not needed for tests -"./tests/*.py"= ["RUF012", "RUF013"] +"./tests/*.py" = ["RUF012", "RUF013"] [tool.ruff.lint.mccabe] # Unlike Flake8, ruff default to a complexity level of 10. @@ -35,17 +35,13 @@ exclude = """(?x)( [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. -env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", ] -markers = [ - "slow: marks tests as slow (deselect with '-m \"not slow\"')" -] +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] [tool.coverage.run] disable_warnings = ["couldnt-parse"] @@ -53,4 +49,8 @@ branch = false omit = ["tests/*", "setup.py"] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 77dfde38..1a0d53aa 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -304,7 +304,7 @@ def predict( with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks) # Convert to numpy - actions = actions.cpu().numpy() + actions = actions.cpu().numpy() # type: ignore[assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 7cd97cf3..5b976939 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union import numpy as np import torch as th @@ -455,3 +455,6 @@ def learn( reset_num_timesteps=reset_num_timesteps, progress_bar=progress_bar, ) + + def _excluded_save_params(self) -> List[str]: + return super()._excluded_save_params() + ["_last_lstm_states"] # noqa: RUF005 diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 2d22b158..ee717ba1 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.4.0a4 +2.4.0a8