This repository contains the evaluation and interpretability code for the paper "Planning behavior in a recurrent neural network that plays Sokoban" (OpenReview, ICML Mechanistic Interpretability Workshop) (arXiv).
The lp-training repository lets you train the neural networks on Sokoban. If you just want to train the DRC networks, you should go there.
The repository can be installed with pip:
pip install -e .
We also provide a Dockerfile for running the code:
docker build -t learned-planner .
docker run -it learned-planner
We install jax[cpu] by default. JAX is only used to obtain the cache in the plot/behavior_analysis.py script. To run that script on a GPU, install JAX with CUDA support instead:
pip uninstall jax
pip install jax[cuda]
We implemented a faster version of the Sokoban environment in C++ using the Envpool library. In our testing, Envpool only works on Linux. We provide Python wheels for the library in the releases of our Envpool fork (the wheel below targets CPython 3.10 on x86_64 Linux):
pip install https://github.com/AlignmentResearch/envpool/releases/download/v0.2.0/envpool-0.8.4-cp310-cp310-linux_x86_64.whl
To build Envpool from source, follow the instructions in the original Envpool documentation, but use our fork instead of the upstream repository.
The trained networks are available in our Hugging Face model hub, which contains all the checkpoints for the ResNet, DRC(1, 1), and DRC(3, 3) models trained with different hyperparameters. The best model of each type is available at:
- DRC(3, 3): drc33/bkynosqi/cp_2002944000
- DRC(1, 1): drc11/eue6pax7/cp_2002944000
- ResNet: resnet/syb50iz7/cp_2002944000
Probes and SAEs trained on the DRC(3, 3) model are available in the same Hugging Face model hub under the probes and saes directories.
First, you need to clone the Boxoban levels. We assume that the levels are stored in the training/.sokoban_cache directory. To use a different directory, set the new path in the learned_planner/__init__.py file. You can clone the levels with the following commands:
BOXOBAN_CACHE="training/.sokoban_cache" # change if desired
mkdir -p "$BOXOBAN_CACHE"
git clone https://github.com/google-deepmind/boxoban-levels \
"$BOXOBAN_CACHE/boxoban-levels-master"
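As a quick sanity check after cloning, the sketch below parses boxoban-levels text files into grids. It assumes the standard layout of that repository (a "; <id>" header before each level, grid rows using # for walls, @ for the player, $ for boxes, and . for targets, blank lines between levels, and files under <difficulty>/<split>/*.txt). parse_boxoban_file and load_split are hypothetical helpers, not part of learned_planner.

```python
from __future__ import annotations

from pathlib import Path


def parse_boxoban_file(text: str) -> dict[int, list[str]]:
    """Split one boxoban-levels .txt file into {level_id: grid rows}.

    Assumed format: a "; <id>" header line, then the grid rows
    ('#' wall, '@' player, '$' box, '.' target, ' ' floor),
    with blank lines separating consecutive levels.
    """
    levels: dict[int, list[str]] = {}
    current_id = None
    for line in text.splitlines():
        if line.startswith(";"):
            # Header line: start a new level.
            current_id = int(line[1:].strip())
            levels[current_id] = []
        elif line.strip() and current_id is not None:
            # Non-blank line after a header: one row of the grid.
            levels[current_id].append(line)
    return levels


def load_split(
    cache_dir: str | Path, difficulty: str = "unfiltered", split: str = "valid"
) -> dict[str, dict[int, list[str]]]:
    """Parse every .txt file of one difficulty/split (hypothetical helper)."""
    root = Path(cache_dir) / "boxoban-levels-master" / difficulty / split
    return {path.name: parse_boxoban_file(path.read_text()) for path in sorted(root.glob("*.txt"))}
```

For example, load_split("training/.sokoban_cache") should return one entry per level file of the unfiltered validation split; an empty result usually means the clone landed in a different directory.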
You can load the model using the following code:
from cleanba import cleanba_impala
from learned_planner.policies import download_policy_from_huggingface
from learned_planner.interp.utils import get_boxoban_cfg
MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000/"  # DRC(3, 3) 2B checkpoint
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)  # fetch the checkpoint from the Hugging Face hub
env_cfg = get_boxoban_cfg().make()
jax_policy, carry_t, jax_args, train_state, _ = cleanba_impala.load_train_state(MODEL_PATH, env_cfg)
The jax_policy loads the network using the JAX implementation of the DRC network from the lp-training repository.
This repository provides the PyTorch implementation of the DRC network compatible with MambaLens for doing interpretability research. You can load the model using the following code:
from learned_planner.interp.utils import load_jax_model_to_torch
cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, env_cfg)
The behavioral results from the paper can be reproduced with the plot/behavior_analysis.py script:
python plot/behavior_analysis.py
The script runs on CPU or GPU depending on which JAX build is installed; see the installation section for details. It saves the plots in the {output_base_path}/{model_name}/plots directory.
The A* solutions for the Boxoban levels can be found here.
To train the probes, we first need to generate a dataset of model activations. The learned_planner/interp/collect_dataset.py script caches the activations, storing those of each level in a separate pickle object of the class learned_planner.interp.collect_dataset.DatasetStore. The DatasetStore …
The script uses the DRC(3, 3) model by default; see the script for additional options.
python learned_planner/interp/collect_dataset.py --boxoban_cache {BOXOBAN_CACHE} --output_path {activation_cache_path}
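Once the script has finished, the cached levels can be inspected with the standard pickle module. This is a minimal sketch assuming one pickle file per level with a .pkl suffix; iter_activation_files and its "*.pkl" default pattern are hypothetical, so adjust the pattern to the file names collect_dataset.py actually writes. Unpickling real DatasetStore objects also requires learned_planner to be importable.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any, Iterator


def iter_activation_files(cache_dir: str | Path, pattern: str = "*.pkl") -> Iterator[tuple[Path, Any]]:
    """Yield (path, unpickled object) for each matching file in cache_dir, in sorted order.

    Assumes one pickle file per level; "*.pkl" is a hypothetical default pattern.
    """
    for path in sorted(Path(cache_dir).glob(pattern)):
        with path.open("rb") as f:
            yield path, pickle.load(f)
```

Only unpickle files you generated yourself, since pickle.load can execute arbitrary code.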
The learned_planner/interp/save_ds.py script takes the activations path (--dataset_path) and a label type (--labels_type), builds the probe dataset for that label type, and saves it as a torch dataset (learned_planner.interp.train_probes.ActivationsDataset) in the same path. See the script for additional options.
python learned_planner/interp/save_ds.py --dataset_path {activation_cache_path} --labels_type {labels_type}
The files in experiments/probes/ define the hyperparameter search space for the different probes. Running a file trains a probe with each hyperparameter configuration. The plot/interp/probes/probe_hp_search.py script plots the results of the hyperparameter search and picks the best probe on the validation set. The scripts in the experiments directory run the appropriate shell command to train the probes. Alternatively, you can train a probe directly with the command below. The default config is in learned_planner/configs/train_probe.py, and you can override its arguments with the cmd.{argument}={value} syntax.
WANDB_MODE=disabled python -m learned_planner --from-py-fn=learned_planner.configs.train_probe:train_local cmd.train_on.layer={layer} cmd.train_on.dataset_name={dataset_name} cmd.dataset_path={dataset_path}
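The experiment files expand a search space into one training run per configuration. A minimal sketch of that idea, assuming a plain dict-of-lists search space, builds the cmd.{argument}={value} overrides for each run and selects the run with the highest validation score. expand_grid, to_cmd_overrides, best_config, the val_f1 key, and the example values are hypothetical, not the repository's API.

```python
from __future__ import annotations

import itertools
import shlex
from typing import Any


def expand_grid(search_space: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Return one config dict per combination in the Cartesian product of the search space."""
    keys = list(search_space)
    return [dict(zip(keys, values)) for values in itertools.product(*search_space.values())]


def to_cmd_overrides(config: dict[str, Any]) -> list[str]:
    """Format a config as cmd.{argument}={value} command-line overrides."""
    return [f"cmd.{key}={value}" for key, value in config.items()]


def best_config(results: list[dict[str, Any]], metric: str = "val_f1") -> dict[str, Any]:
    """Pick the run with the highest validation metric (hypothetical metric key)."""
    return max(results, key=lambda run: run[metric])


if __name__ == "__main__":
    # Hypothetical values; the real layers and dataset names come from the experiment files.
    space = {"train_on.layer": [0, 1], "train_on.dataset_name": ["my_labels"]}
    for config in expand_grid(space):
        print(shlex.join(to_cmd_overrides(config)))
```

Each printed line can be appended to the training command above in place of the placeholder overrides.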
The files in experiments/sae/ define the hyperparameter search space for training the SAEs.
The plot/interp/probes/ directory contains scripts that evaluate the different types of probes in multiple ways. These are the main scripts used to evaluate the results in the paper:
- evaluate_probe: evaluates the precision, recall, and F1 scores of the probes on a dataset.
- ci_score_direction_probe: computes the causal intervention (CI) score for box or agent direction probes by modifying a single direction (move) in the plan using the probe and checking whether the agent follows the modified plan.
- ci_score_box_target_probe: computes the CI score for next_box or next_target probes.
- ci_score_from_csv: the scripts above save their results in a CSV file; this script computes the average and best-case CI scores from those CSV files.
- measure_plan_quality: computes the plan quality of the box-direction probe across thinking steps.
- measure_plan_recall: computes the plan recall of the box-direction probe across thinking steps.
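For reference, the two kinds of numbers these scripts report can be sketched in a few lines. The first function is the standard binary (one-vs-rest) precision, recall, and F1 computation. The second aggregates a CSV of intervention outcomes under an assumed schema (one row per intervention, a group column such as an intervention setting, and a 0/1 success column), reporting the mean over groups and the best group. Column names and this definition of "best case" are assumptions, not the exact logic of ci_score_from_csv.

```python
from __future__ import annotations

import csv
from collections import defaultdict
from statistics import mean
from typing import Iterable, TextIO


def precision_recall_f1(predictions: Iterable[int], labels: Iterable[int]) -> tuple[float, float, float]:
    """Binary precision, recall, and F1 for 0/1 probe predictions against 0/1 labels."""
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p and y)
    fp = sum(1 for p, y in pairs if p and not y)
    fn = sum(1 for p, y in pairs if not p and y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def summarize_ci_scores(csv_file: TextIO, group_col: str, score_col: str) -> tuple[float, float]:
    """Return (average, best) of the per-group mean scores in a CI results CSV (assumed schema)."""
    per_group: dict[str, list[float]] = defaultdict(list)
    for row in csv.DictReader(csv_file):
        per_group[row[group_col]].append(float(row[score_col]))
    group_means = [mean(scores) for scores in per_group.values()]
    return mean(group_means), max(group_means)
```

For a multi-class probe (for example, one predicting a direction per square), precision_recall_f1 would be applied once per class with that class treated as positive.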
The plot/interp/evaluate_features.py script can be used to find interpretable features in channels and SAE feature neurons. It can also evaluate probes.
The plot/interp/save_{probe/sae}_videos.py scripts can be used to save videos of probe and SAE features.
If you use this code, please cite our work:
@inproceedings{garriga-alonso2024planning,
title={Planning behavior in a recurrent neural network that plays Sokoban},
author={Adri{\`a} Garriga-Alonso and Mohammad Taufeeque and Adam Gleave},
booktitle={ICML 2024 Workshop on Mechanistic Interpretability},
year={2024},
url={https://openreview.net/forum?id=T9sB3S2hok}
}