Skip to content

Commit

Permalink
Merge branch 'master' into feat/crossq
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Aug 13, 2024
2 parents 03db09e + 42595a5 commit 244b930
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a6 (WIP)
Release 2.4.0a8 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -19,6 +19,7 @@ Bug Fixes:
^^^^^^^^^^
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Updated QR-DQN paper link in docs (@corentinlger)
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)

Deprecations:
^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/common/maskable/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def predict(
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy()
actions = actions.cpu().numpy() # type: ignore[assignment]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
Expand Down
5 changes: 4 additions & 1 deletion sb3_contrib/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
Expand Down Expand Up @@ -455,3 +455,6 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["_last_lstm_states"] # noqa: RUF005
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a6
2.4.0a8

0 comments on commit 244b930

Please sign in to comment.