World Models via Policy-Guided Trajectory Diffusion (PolyGRAD)

Official code to reproduce the experiments for the paper World Models via Policy-Guided Trajectory Diffusion. PolyGRAD diffuses an initially random trajectory of states and actions into an on-policy trajectory, and uses the synthetic data for imagined on-policy RL training.
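As a rough illustration of the idea only (this is not the repository's implementation), the toy sketch below starts from a random trajectory and alternates two updates: a denoising step on the states, and a gradient step that raises the log-likelihood of the actions under a policy. The names `toy_denoiser`, `policy_mean`, and `guided_denoise`, the linear Gaussian policy, and all step sizes are hypothetical stand-ins; the actual method uses a learned denoising network and is specified in the paper.

```python
import numpy as np


def policy_mean(states, weights):
    # Hypothetical linear Gaussian policy: the mean action is states @ weights.
    return states @ weights


def toy_denoiser(states, rate=0.1):
    # Hypothetical stand-in for a learned denoising network:
    # pulls every state a little toward the trajectory's mean state.
    return states + rate * (states.mean(axis=0, keepdims=True) - states)


def guided_denoise(horizon=10, state_dim=3, action_dim=2,
                   n_steps=100, sigma=1.0, action_step=0.2, seed=0):
    rng = np.random.default_rng(seed)
    weights = rng.normal(size=(state_dim, action_dim))

    # Start from a fully random trajectory of states and actions.
    states = rng.normal(size=(horizon, state_dim))
    actions = rng.normal(size=(horizon, action_dim))

    for _ in range(n_steps):
        # 1) Denoise the states.
        states = toy_denoiser(states)
        # 2) Gradient of log N(actions; policy_mean(states), sigma^2)
        #    with respect to the actions, followed by a small ascent step.
        grad = -(actions - policy_mean(states, weights)) / sigma**2
        actions = actions + action_step * grad

    return states, actions, weights


if __name__ == "__main__":
    s, a, w = guided_denoise()
    print("max |action - policy mean|:", np.max(np.abs(a - policy_mean(s, w))))
```

In this toy setting the actions drift toward what the policy would choose in the denoised states, which is the sense in which the resulting trajectory is "on-policy".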

Installation

  1. Install MuJoCo 2.1.0 to ~/.mujoco/mujoco210.
  2. Install requirements and package.
conda create -n polygrad python=3.10
conda activate polygrad
cd polygrad-world-models
pip install -r requirements.txt
pip install -e .

Tested with Python 3.10.
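A quick sanity check of the setup can be sketched as follows. It only verifies the two facts stated above: that the MuJoCo directory exists and that the interpreter is Python 3.10. The function name `check_install` is hypothetical and is not part of this repo.

```python
import sys
from pathlib import Path


def check_install(mujoco_dir=None, expected_python=(3, 10)):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    if mujoco_dir is None:
        # Install location given in step 1 of the installation instructions.
        mujoco_dir = Path.home() / ".mujoco" / "mujoco210"

    problems = []
    if not Path(mujoco_dir).is_dir():
        problems.append(f"MuJoCo directory not found: {mujoco_dir}")
    if tuple(sys.version_info[:2]) != tuple(expected_python):
        found = ".".join(map(str, sys.version_info[:2]))
        wanted = ".".join(map(str, expected_python))
        problems.append(f"Python {found} in use; the code was tested with {wanted}")
    return problems


if __name__ == "__main__":
    for problem in check_install():
        print("WARNING:", problem)
```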

Running PolyGRAD

Online RL Experiments

To run online RL experiments:

python3 scripts/online_rl.py --config config.online_rl.hopper

Training World Model from Fixed Datasets

Download the datasets from Google Drive or the Hugging Face Hub and store them in polygrad-world-models/datasets. To train PolyGRAD world models and evaluate their errors, run:

python3 scripts/train_diffusion_wm.py --config config.world_model_only.polygrad_mlp.hopper_h10

Config files for the MLP and transformer denoising networks, as well as for different trajectory lengths, are provided in the config folder.
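Before launching training, it can help to confirm that the datasets are where the scripts expect them. The sketch below simply lists whatever is in polygrad-world-models/datasets. The helper name `list_datasets` is hypothetical, and no assumption is made about dataset file names or formats.

```python
from pathlib import Path


def list_datasets(repo_root="polygrad-world-models"):
    """Return the sorted names of all entries in <repo_root>/datasets."""
    datasets_dir = Path(repo_root) / "datasets"
    if not datasets_dir.is_dir():
        raise FileNotFoundError(f"Expected a datasets directory at {datasets_dir}")
    return sorted(entry.name for entry in datasets_dir.iterdir())


if __name__ == "__main__":
    try:
        names = list_datasets()
    except FileNotFoundError as err:
        print(err)
    else:
        print("No datasets found." if not names else "\n".join(names))
```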

Running Baselines

This repo also provides implementations of the autoregressive diffusion and transformer world model baselines from the paper. Ensure that the datasets are in polygrad-world-models/datasets. Then, to train the autoregressive diffusion world model baseline:

python3 scripts/train_diffusion_wm.py --config config.world_model_only.autoregressive_diffusion.hopper

Note that the transformer world model baseline uses a different script:

python3 scripts/train_transformer_wm.py --config config.world_model_only.transformer_wm.hopper
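Since the world-model runs use two different scripts, a small launcher can keep the script and config pairs in one place. The sketch below uses only the commands shown in this README. The names `WORLD_MODEL_RUNS`, `build_command`, and `run_all` are hypothetical, and the launcher defaults to a dry run that prints the commands without executing them.

```python
import subprocess

# Script/config pairs copied from the commands in this README.
WORLD_MODEL_RUNS = {
    "polygrad_mlp": (
        "scripts/train_diffusion_wm.py",
        "config.world_model_only.polygrad_mlp.hopper_h10",
    ),
    "autoregressive_diffusion": (
        "scripts/train_diffusion_wm.py",
        "config.world_model_only.autoregressive_diffusion.hopper",
    ),
    "transformer_wm": (
        "scripts/train_transformer_wm.py",
        "config.world_model_only.transformer_wm.hopper",
    ),
}


def build_command(name, python="python3"):
    """Build the argument list for one world-model training run."""
    script, config = WORLD_MODEL_RUNS[name]
    return [python, script, "--config", config]


def run_all(dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for name in WORLD_MODEL_RUNS:
        cmd = build_command(name)
        print(" ".join(cmd))
        if not dry_run:
            # Run from the repository root so the relative script paths resolve.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```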

For the MLP ensemble baseline, we used the code from mbpo_pytorch. For the Dreamer-v3 baseline, we used the dreamerv3-torch repo. Lastly, for the model-free RL baselines, we used Stable Baselines 3.

Citing this work

@article{rigter2023world,
  title={World Models via Policy-Guided Trajectory Diffusion},
  author={Rigter, Marc and Yamada, Jun and Posner, Ingmar},
  journal={Transactions on Machine Learning Research},
  year={2024}
}

Acknowledgements

Our implementation utilises code from Diffuser, nanoGPT, and SynthER.

Additional installation steps

Install GLFW via conda:

conda install -c menpo glfw3

Then add your conda environment include to CPATH (put this in your .bashrc to make it permanent):

export CPATH=$CONDA_PREFIX/include

Finally, install patchelf:

pip install patchelf
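The CPATH and patchelf steps above can be verified with a short script. This is a minimal sketch, assuming a POSIX shell environment with an activated conda environment; the function name `check_build_env` is hypothetical.

```python
import os
import shutil


def check_build_env(environ=None):
    """Return a list of problems with the CPATH and patchelf setup; empty if none are found."""
    env = os.environ if environ is None else environ
    problems = []

    conda_prefix = env.get("CONDA_PREFIX")
    cpath_entries = env.get("CPATH", "").split(os.pathsep)
    if conda_prefix is None:
        problems.append("CONDA_PREFIX is not set; activate the conda environment first")
    elif os.path.join(conda_prefix, "include") not in cpath_entries:
        problems.append("CPATH does not contain $CONDA_PREFIX/include")

    # Look for the patchelf executable on the PATH of the given environment.
    if shutil.which("patchelf", path=env.get("PATH")) is None:
        problems.append("patchelf was not found on PATH")

    return problems


if __name__ == "__main__":
    for problem in check_build_env():
        print("WARNING:", problem)
```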