Official code to reproduce the experiments for the paper World Models via Policy-Guided Trajectory Diffusion. PolyGRAD diffuses an initially random trajectory of states and actions into an on-policy trajectory, and uses the synthetic data for imagined on-policy RL training.
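To give a feel for the idea above, here is a toy, illustrative sketch of policy-guided trajectory diffusion: an initially random trajectory of states and actions is iteratively denoised while the actions are nudged toward the current policy. Everything here (`toy_denoise`, `policy_mean`, the schedules and guidance weight) is hypothetical and stands in for the learned denoiser and policy; it is not the PolyGRAD implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
horizon, obs_dim, act_dim = 10, 3, 2
n_steps = 50  # number of denoising steps (toy value)

def policy_mean(states):
    """Stand-in for the current policy: a fixed linear map from state to action."""
    W = np.full((obs_dim, act_dim), 0.1)
    return states @ W

def toy_denoise(traj, t):
    """Stand-in for the learned denoiser: shrink the trajectory toward zero."""
    return traj * (1.0 - 1.0 / n_steps)

# Start from pure noise over the whole trajectory of states and actions.
states = rng.normal(size=(horizon, obs_dim))
actions = rng.normal(size=(horizon, act_dim))

for t in reversed(range(n_steps)):
    states = toy_denoise(states, t)
    # Policy guidance: pull the actions toward the policy's output so the
    # final synthetic trajectory is (approximately) on-policy.
    actions = actions + 0.1 * (policy_mean(states) - actions)

print(states.shape, actions.shape)  # synthetic trajectory for imagined RL
```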
- Install MuJoCo 2.1.0 to `~/.mujoco/mujoco210`.
- Install the requirements and the package:
```
conda create -n polygrad python=3.10
conda activate polygrad
cd polygrad-world-models
pip install -r requirements.txt
pip install -e .
```
Tested with Python 3.10.
To run online RL experiments:
```
python3 scripts/online_rl.py --config config.online_rl.hopper
```
Download the datasets from Google Drive or the Hugging Face Hub and store them in `polygrad-world-models/datasets`. Train PolyGRAD world models and evaluate their prediction errors using:
```
python3 scripts/train_diffusion_wm.py --config config.world_model_only.polygrad_mlp.hopper_h10
```
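Before launching a long training run, it can help to confirm the datasets are actually in place. A minimal sketch, where `datasets_present` is a hypothetical helper (not part of this repo) and the directory name is taken from the instructions above:

```python
from pathlib import Path

def datasets_present(root="polygrad-world-models/datasets"):
    """Return True if the datasets directory exists and is non-empty."""
    root = Path(root)
    return root.is_dir() and any(root.iterdir())

if not datasets_present():
    print("Datasets missing: download them to polygrad-world-models/datasets first.")
```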
Config files for the MLP and transformer denoising networks, as well as for different trajectory lengths, are provided in the `config` folder.
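The `--config` values are dotted, Python-style paths into the `config` folder. Assuming the usual convention that `config.online_rl.hopper` corresponds to `config/online_rl/hopper.py` (an assumption about this repo's layout, not verified here), the mapping is just:

```python
def config_to_path(dotted):
    """Map a dotted config name to its likely file path (hypothetical helper)."""
    return dotted.replace(".", "/") + ".py"

print(config_to_path("config.online_rl.hopper"))  # config/online_rl/hopper.py
```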
In this repo, we also provide implementations of the autoregressive diffusion and transformer world model baselines from the paper. Ensure that the datasets are in `polygrad-world-models/datasets`. Then, to train the autoregressive diffusion world model baseline:
```
python3 scripts/train_diffusion_wm.py --config config.world_model_only.autoregressive_diffusion.hopper
```
Note that the transformer world model baseline uses a different script:
```
python3 scripts/train_transformer_wm.py --config config.world_model_only.transformer_wm.hopper
```
For the MLP ensemble baseline we used the code from mbpo_pytorch. For the Dreamer-v3 baseline we used the dreamerv3-torch repo. Lastly, for the model-free RL baselines we used Stable Baselines 3.
```bibtex
@article{rigter2023world,
  title={World Models via Policy-Guided Trajectory Diffusion},
  author={Rigter, Marc and Yamada, Jun and Posner, Ingmar},
  journal={Transactions on Machine Learning Research},
  year={2024}
}
```
Our implementation utilises code from Diffuser, nanoGPT, and SynthER.
```
conda install -c menpo glfw3
```

Then add your conda environment's include directory to `CPATH` (put this in your `.bashrc` to make it permanent):

```
export CPATH=$CONDA_PREFIX/include
```

Finally, install patchelf:

```
pip install patchelf
```