
Stable-Baselines3 v2.3.0: New default hyperparameters for DDPG, TD3 and DQN

@araffin released this on 31 Mar 18:33
429be93

Warning

Because torch.load() is now called with weights_only=True, this release breaks loading of policies when using PyTorch 1.13.
Please upgrade to PyTorch >= 2.0 or upgrade SB3 (the change was reverted in SB3 2.3.2).

SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
Stable-Baselines Jax (SBX): https://github.com/araffin/sbx

To upgrade:

pip install stable_baselines3 sb3_contrib --upgrade

or simply (the RL Zoo depends on SB3 and SB3 Contrib):

pip install rl_zoo3 --upgrade

Breaking Changes:

  • The default hyperparameters of TD3 and DDPG have been changed to be more consistent with SAC:
  from stable_baselines3 import TD3

  env = "Pendulum-v1"  # example env id; any continuous-action environment works
  # SB3 < 2.3.0 default hyperparameters
  # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)
  # SB3 >= 2.3.0:
  model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)

Note

Two inconsistencies remain: the default network architecture for TD3/DDPG is [400, 300] instead of [256, 256] as in SAC (for backward compatibility reasons; see the report on the influence of the network size), and the default learning rate is 1e-3 instead of 3e-4 as in SAC (for performance reasons; see the W&B report on the influence of the learning rate).
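Code written against the old defaults can keep its previous behavior by passing the old values explicitly. A minimal sketch, assuming you want either to restore the pre-2.3.0 training schedule or to close the two remaining gaps with SAC described in the note. The dictionary names are hypothetical; the keys are regular TD3/DDPG constructor arguments.

```python
# Pre-2.3.0 TD3/DDPG training schedule, copied from the "SB3 < 2.3.0" comment above.
LEGACY_TD3_KWARGS = {
    "train_freq": (1, "episode"),  # update the model once per episode...
    "gradient_steps": -1,          # ...with as many gradient steps as env steps collected
    "batch_size": 100,
}

# Overrides for the two remaining differences with SAC (learning rate, network size).
# This is an illustrative choice, not a preset shipped with SB3.
SAC_LIKE_TD3_KWARGS = {
    "learning_rate": 3e-4,
    "policy_kwargs": {"net_arch": [256, 256]},
}
```

The dictionaries are unpacked into the constructor, e.g. `TD3("MlpPolicy", env, **LEGACY_TD3_KWARGS)` to reproduce older runs, or `**SAC_LIKE_TD3_KWARGS` for a setup closer to SAC. Keeping them as plain data also makes it easy to log exactly which hyperparameters a run used.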

  • The default learning_starts parameter of DQN has been changed to be consistent with the other off-policy algorithms:
  from stable_baselines3 import DQN

  env = "CartPole-v1"  # example env id; any discrete-action environment works
  # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to the Atari default hyperparameters
  # model = DQN("MlpPolicy", env, learning_starts=50_000)
  # SB3 >= 2.3.0:
  model = DQN("MlpPolicy", env, learning_starts=100)
  • For safety, torch.load() is now called with weights_only=True when loading torch tensors;
    policy load() still uses weights_only=False because it requires gymnasium imports to work
  • When using huggingface_sb3, you will now need to set TRUST_REMOTE_CODE=True when downloading models from the Hugging Face Hub, as pickle.load() is not safe on untrusted data.
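The weights_only=True change and the new TRUST_REMOTE_CODE requirement have the same root cause: unpickling can execute arbitrary code. The standard-library sketch below illustrates the usual mitigation, an unpickler that only resolves allowlisted globals. It is a conceptual analogue that works in a similar spirit to weights_only=True, not PyTorch's implementation; `SAFE_GLOBALS`, `RestrictedUnpickler`, `restricted_loads`, `is_refused` and `RunsCodeOnLoad` are hypothetical names.

```python
import io
import pickle

# Globals this sketch allows; everything else is refused during unpickling.
# Purely illustrative: PyTorch's weights_only mode has its own, different allowlist.
SAFE_GLOBALS = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves allowlisted (module, name) globals."""

    def find_class(self, module: str, name: str):
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Like pickle.loads(), but refuses to import arbitrary callables."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


def is_refused(data: bytes) -> bool:
    """True if the payload references a global outside the allowlist."""
    try:
        restricted_loads(data)
    except pickle.UnpicklingError:
        return True
    return False


class RunsCodeOnLoad:
    """Plain pickle.loads() on this object's bytes would call print(): code runs on load."""

    def __reduce__(self):
        return (print, ("this ran during unpickling",))
```

Plain containers of numbers (dicts, lists, floats) load without touching `find_class`, while a payload that smuggles in a callable such as `print` is rejected before it can run.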

New Features:

  • Log the success rate (rollout/success_rate) when available for on-policy algorithms (@corentinlger)
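Success is read from the episode info: SB3 looks for an `is_success` key in the `info` dict returned by the environment. The sketch below mimics the bookkeeping with a moving window over finished episodes. The class name and the window size of 100 are assumptions, and SB3's internal code may differ in details.

```python
from collections import deque
from typing import List, Optional


class SuccessTracker:
    """Moving-window success rate computed from vectorized `infos` and `dones`."""

    def __init__(self, window: int = 100) -> None:
        self.buffer: deque = deque(maxlen=window)  # oldest episodes drop out first

    def update(self, infos: List[dict], dones: List[bool]) -> None:
        """Record one vectorized step: one info dict and one done flag per env."""
        for info, done in zip(infos, dones):
            is_success = info.get("is_success")
            # Only finished episodes count; `dones` tells which infos mark an episode end.
            if done and is_success is not None:
                self.buffer.append(float(is_success))

    def success_rate(self) -> Optional[float]:
        """Mean success over the window, or None if no episode has finished yet."""
        if not self.buffer:
            return None
        return sum(self.buffer) / len(self.buffer)
```

Because the done flags decide which infos are counted, forgetting to pass them would mix intermediate steps into the statistic.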

Bug Fixes:

  • Fixed the monitor_wrapper argument, which was not passed to the parent class, and the dones argument, which was not passed to _update_info_buffer (@corentinlger)

SB3-Contrib

  • Added rollout_buffer_class and rollout_buffer_kwargs arguments to MaskablePPO
  • Fixed train_freq type annotation for TQC and QR-DQN (@Armandpl)
  • Fixed sb3_contrib/common/maskable/*.py type annotations
  • Fixed sb3_contrib/ppo_mask/ppo_mask.py type annotations
  • Fixed sb3_contrib/common/vec_env/async_eval.py type annotations
  • Added notes about MaskablePPO evaluation and multi-process usage (@icheered)
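MaskablePPO only samples actions that an action mask marks as valid. A minimal plain-Python sketch of the underlying idea: give invalid actions a logit of -inf so the softmax assigns them probability zero. sb3_contrib applies masking to torch distributions and may use a large negative constant instead of -inf; `masked_softmax` is a hypothetical helper for illustration only.

```python
import math
from typing import List


def masked_softmax(logits: List[float], mask: List[bool]) -> List[float]:
    """Action probabilities where invalid actions (mask False) get probability 0."""
    if len(logits) != len(mask):
        raise ValueError("logits and mask must have the same length")
    if not any(mask):
        raise ValueError("at least one action must be valid")
    # Replace invalid logits by -inf so that exp() maps them to exactly 0.
    masked = [x if ok else -math.inf for x, ok in zip(logits, mask)]
    top = max(masked)  # finite, since at least one action is valid
    exps = [math.exp(x - top) for x in masked]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]
```

The remaining probability mass is renormalized over the valid actions only, so sampling from the result can never pick a masked action.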

RL Zoo

  • Updated default hyperparameters for TD3/DDPG to be more consistent with SAC
  • Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated)
  • Added test dependencies to setup.py (@power-edge)
  • Simplified the dependencies in requirements.txt (removed duplicates from setup.py)

SBX (SB3 + Jax)

  • Added support for MultiDiscrete and MultiBinary action spaces to PPO
  • Added support for large values for gradient_steps to SAC, TD3, and TQC
  • Fixed train() signature and updated type hints
  • Fixed replay buffer device at load time
  • Added flatten layer
  • Added CrossQ

Others:

  • Updated black from v23 to v24
  • Updated ruff to >= v0.3.1
  • Updated env checker for (multi)discrete spaces with non-zero start.

Documentation:

  • Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
  • Updated callback code example
  • Updated the ONNX export documentation: it is now much simpler to export SB3 models with newer ONNX opsets!
  • Added a link to the "Practical Tips for Reliable Reinforcement Learning" video
  • Added render_mode="human" in the README example (@marekm4)
  • Fixed docstring signature for sum_independent_dims (@StagOverflow)
  • Updated docstring description for log_interval in the base class (@rushitnshah).

Full Changelog: v2.2.1...v2.3.0