1 parent 26e0c2c commit 1081c00
Showing 19 changed files with 11,372 additions and 2,446 deletions.
@@ -0,0 +1,61 @@
# Zero-Shot Learning of Causal Models (Cond-FiP)
[![Static Badge](https://img.shields.io/badge/paper-CondFiP-brightgreen?style=plastic&label=Paper&labelColor=yellow)](https://arxiv.org/pdf/2410.06128)

This repo implements Cond-FiP, proposed in the paper "Zero-Shot Learning of Causal Models".

Cond-FiP is a transformer-based approach for inferring Structural Causal Models (SCMs) in a zero-shot manner. Rather than learning a specific SCM for each dataset, we extend the Fixed-Point Approach (FiP) proposed in [Scetbon et al. (2024)](https://openreview.net/pdf?id=JpzIGzru5F) to infer generative SCMs conditionally on their empirical representations. More specifically, we amortize the learning of a conditional version of FiP, which infers generative SCMs from observations and causal structures, over synthetically generated datasets.

Cond-FiP is composed of two models: (1) a dataset Encoder that produces embeddings from the empirical representations of SCMs, and (2) a Decoder that, conditionally on the dataset embedding, infers the generative functional mechanisms of the associated SCM.
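
To make the fixed-point view of an SCM concrete, here is a minimal sketch that is not taken from this repository. A sample `x` is a solution of `x = F(x, n)` for a noise vector `n`, and on a DAG with `d` nodes, iterating `F` reaches that solution in at most `d` steps. The 3-node chain, its weights, and the helper names `mechanism` and `solve_fixed_point` are illustrative assumptions.

```python
# Hypothetical 3-node chain x0 -> x1 -> x2 with linear mechanisms.
# Row j holds the weights of the parents of node j (illustrative values).
WEIGHTS = [
    [0.0, 0.0, 0.0],   # x0 has no parents: x0 = n0
    [2.0, 0.0, 0.0],   # x1 = 2 * x0 + n1
    [0.0, -1.0, 0.0],  # x2 = -x1 + n2
]


def mechanism(x, noise):
    """One application of F: recompute every node from its parents and its noise."""
    return [
        sum(w * xi for w, xi in zip(row, x)) + n
        for row, n in zip(WEIGHTS, noise)
    ]


def solve_fixed_point(noise):
    """Iterate F starting from zeros; d iterations suffice on a DAG with d nodes."""
    x = [0.0] * len(noise)
    for _ in range(len(noise)):
        x = mechanism(x, noise)
    return x


noise = [1.0, 0.5, -0.25]
sample = solve_fixed_point(noise)
# Applying F once more leaves the sample unchanged, so it is a fixed point.
assert mechanism(sample, noise) == sample
```

In Cond-FiP the mechanisms are not hand-written as above; the Decoder infers them from the dataset embedding.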

## Dependency
We use [Poetry](https://python-poetry.org/) to manage the project dependencies, which are specified in the [pyproject](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```
To install the environment, run `poetry install` in the directory of the cond_fip project.

## Run experiments
In the [launchers](src/cond_fip/launchers) directory, we provide scripts to train both the Encoder and the Decoder.

### Amortized Learning of the Encoder
To train the Encoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_encoder
```
The model and the config file will be saved in `src/cond_fip/outputs`.

### Amortized Learning of Cond-FiP
To train the Decoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_cond_fip \
    --run_id <name_of_the_directory_containing_the_trained_encoder_model>
```
The model and the config file will be saved in `src/cond_fip/outputs`. This command assumes that an Encoder model has already been trained and saved in the directory `src/cond_fip/outputs/<name_of_the_directory_containing_the_trained_encoder_model>`.
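
Before launching the Decoder training, it can help to check that the Encoder checkpoint is where the command expects it. The helper below is a hypothetical sketch, not part of the repository: the `outputs/best_model.ckpt` suffix is an assumption inferred from the default `encoder_model_path` in `cond_fip_training.yaml`, and the launcher itself may resolve paths differently.

```python
from pathlib import Path

# Root directory taken from the README; adjust it if your checkout differs.
OUTPUTS_ROOT = Path("src/cond_fip/outputs")


def resolve_checkpoint(run_id: str, root: Path = OUTPUTS_ROOT) -> Path:
    """Return the checkpoint path for `run_id`, or raise if the file is missing.

    Hypothetical helper: the `outputs/best_model.ckpt` layout mirrors the
    default path in the training config and is an assumption.
    """
    ckpt = root / run_id / "outputs" / "best_model.ckpt"
    if not ckpt.is_file():
        raise FileNotFoundError(f"No checkpoint at {ckpt}; train the Encoder first.")
    return ckpt
```

Calling `resolve_checkpoint("<run_id>")` from the project directory either returns the path or fails early with a readable error, which is cheaper than discovering a missing file after the data pipeline has started.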

### Test Cond-FiP on a new Dataset
To test a trained Cond-FiP, we also provide a [launcher file](src/cond_fip/launchers/inference_cond_fip.py) that infers SCMs with Cond-FiP on new datasets.

To use this file, provide the path to the data in the [config file](src/cond_fip/config/numpy_tensor_data_module.yaml) by replacing the value of `data_dir`.
The data must follow a specific format. You can generate example datasets by running:

```console
python -m fip.data_generation.avici_data --func_type linear --graph_type er --noise_type gaussian --dist_case in --seed 1 --data_dim 5 --num_interventions 5
```
The data will be stored in `./data`.
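
The default `data_dir` in `numpy_tensor_data_module.yaml` is `./data/er_linear_gaussian_in/total_nodes_5/seed_1/`, which lines up with the flags of the command above. The sketch below rebuilds that path from the flags. The naming pattern is an assumption inferred from this single example, and `expected_data_dir` is a hypothetical helper, so check the generated directory before relying on it.

```python
from pathlib import Path


def expected_data_dir(
    graph_type: str,
    func_type: str,
    noise_type: str,
    dist_case: str,
    data_dim: int,
    seed: int,
    root: str = "./data",
) -> Path:
    """Guess the output directory of the data generation command.

    Assumed pattern: <root>/<graph>_<func>_<noise>_<dist>/total_nodes_<dim>/seed_<seed>
    """
    name = f"{graph_type}_{func_type}_{noise_type}_{dist_case}"
    return Path(root) / name / f"total_nodes_{data_dim}" / f"seed_{seed}"


# Flags of the example command: er graph, linear functions, gaussian noise,
# in-distribution case, 5 nodes, seed 1.
data_dir = expected_data_dir("er", "linear", "gaussian", "in", 5, 1)
```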

To test a pre-trained Cond-FiP model on a specific dataset, run:
```console
python -m cond_fip.launchers.inference_cond_fip \
    --run_id <name_of_the_directory_containing_the_pre_trained_model> \
    --path_data <path_to_the_data>
```

This command assumes that a pre-trained Cond-FiP model has been saved in the directory `src/cond_fip/outputs/<name_of_the_directory_containing_the_pre_trained_model>` and that the data has been saved at `<path_to_the_data>`.

@@ -0,0 +1,53 @@
[tool.poetry]
name = "cond_fip"
version = "0.1.0"
description = "Zero-Shot Learning of Causal Models"
readme = "README.md"
authors = ["Meyer Scetbon", "Divyat Mahajan"]
packages = [
    { include = "cond_fip", from = "src" },
]
license = "MIT"

[tool.poetry.dependencies]
python = "~3.10"
fip = { path = "../fip"}

[tool.poetry.group.dev.dependencies]
black = {version="^22.6.0", extras=["jupyter"]}
isort = "^5.10.1"
jupyter = "^1.0.0"
jupytext = "^1.13.8"
mypy = "^1.0.0"
pre-commit = "^2.19.0"
pylint = "^2.14.4"
pytest = "^7.1.2"
pytest-cov = "^3.0.0"
seaborn = "^0.12.2"
types-python-dateutil = "^2.8.18"
types-requests = "^2.31.0.10"
ema-pytorch = "^0.6.0"


[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120

[tool.isort]
line_length = 120
profile = "black"
py_version = 310
known_first_party = ["cond_fip"]

# Keep import sorts by code jupytext percent block (https://github.com/PyCQA/isort/issues/1338)
treat_comments_as_code = ["# %%"]

[tool.pytest.ini_options]
addopts = "--durations=200"
junit_family = "xunit1"

11 changes: 11 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/config/cond_fip_inference.yaml
@@ -0,0 +1,11 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.cond_fip_inference.CondFiPInference
  init_args:
    enc_dec_model_path: ./src/cond_fip/outputs/amortized_enc_dec_training_2024-09-09_13-51-00/outputs/best_model.ckpt

trainer:
  logger: MLFlowLogger
  accelerator: gpu
  devices: 1

63 changes: 63 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/config/cond_fip_training.yaml
@@ -0,0 +1,63 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.cond_fip_training.CondFiPTraining
  init_args:
    encoder_model_path: ./src/cond_fip/outputs/amortized_encoder_training_2024-07-02_19-09-00/outputs/best_model.ckpt

    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.95
    weight_decay: 1e-10

    use_scheduler: true
    linear_warmup_steps: 1000
    scheduler_steps: 10_000

    d_model: 256
    num_heads: 8
    num_layers: 4
    d_ff: 512
    dropout: 0.1
    dim_key: 64
    num_layers_dataset: 2

    distributed: false
    with_true_target: true
    final_pair_only: true

    with_ema: true
    ema_beta: 0.99
    ema_update_every: 10

trainer:
  max_epochs: 7000
  logger: MLFlowLogger
  accelerator: gpu
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  accumulate_grad_batches: 16
  log_dir: "./src/cond_fip/logging_enc_dec/"
  inference_mode: false
  devices: 1
  num_nodes: 1

early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"

best_checkpoint_callback:
  dirpath: "./src/cond_fip/logging_enc_dec/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1

last_checkpoint_callback:
  save_last: true
  filename: "last_model"
  save_top_k: 0 # only the last checkpoint is saved
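
The `with_ema`, `ema_beta` and `ema_update_every` keys in this config suggest that an exponential moving average of the weights is kept during training, presumably through the `ema-pytorch` dependency listed in the pyproject. The sketch below is a simplified, scalar illustration of that idea and not the library's implementation: `ema_track` is a hypothetical name, and any extra behaviour the library applies (such as a warm-up of the decay) is omitted.

```python
def ema_track(values, beta=0.99, update_every=10):
    """Track a simplified EMA of a scalar parameter.

    `values[t]` is the online parameter after optimizer step t + 1. The shadow
    copy is refreshed only every `update_every` steps, using
    shadow <- beta * shadow + (1 - beta) * online.
    """
    shadow = values[0]  # initialise the shadow copy from the first online value
    history = []
    for step, online in enumerate(values, start=1):
        if step % update_every == 0:
            shadow = beta * shadow + (1.0 - beta) * online
        history.append(shadow)
    return history


# The online parameter jumps from 0 to 1; the shadow copy follows slowly.
trajectory = ema_track([0.0] + [1.0] * 19)
```

With `beta = 0.99` the shadow copy moves by only 1% of the gap at each refresh, which is why averaged weights tend to be smoother than the online ones.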
59 changes: 59 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/config/encoder_training.yaml
@@ -0,0 +1,59 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.encoder_training.EncoderTraining
  init_args:

    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.95
    weight_decay: 5e-4

    use_scheduler: true
    linear_warmup_steps: 1000
    scheduler_steps: 10_000

    d_model: 256
    num_heads: 8
    num_layers: 4
    d_ff: 512
    dropout: 0.0
    dim_key: 32
    d_hidden_head: 1024

    distributed: false

    with_ema: true
    ema_beta: 0.99
    ema_update_every: 10

trainer:
  max_epochs: 5000
  logger: MLFlowLogger
  accelerator: gpu
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  log_dir: "./src/cond_fip/logging_enc/"
  inference_mode: false
  devices: 1
  num_nodes: 1

early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"

best_checkpoint_callback:
  dirpath: "./src/cond_fip/logging_enc/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1

last_checkpoint_callback:
  save_last: true
  filename: "last_model"
  save_top_k: 0 # only the last checkpoint is saved
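
The `early_stopping_callback` block reads like the arguments of a standard early-stopping rule in `"min"` mode: training stops once `val_loss` has failed to improve by more than `min_delta` for `patience` consecutive validation checks. The class below is a minimal sketch of that rule under this assumption; `EarlyStopper` is a hypothetical name, and how the launcher actually wires these keys into a callback is not shown in this diff.

```python
class EarlyStopper:
    """Simplified early stopping for a loss that should decrease ("min" mode)."""

    def __init__(self, patience: int = 500, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0  # consecutive validation checks without improvement

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Small illustrative run with a short patience.
stopper = EarlyStopper(patience=3, min_delta=0.1)
decisions = [stopper.step(loss) for loss in [1.0, 0.95, 0.94, 0.93]]
```

With `check_val_every_n_epoch: 1`, a patience of 500 corresponds to 500 epochs without a sufficient improvement.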
12 changes: 12 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/config/numpy_tensor_data_module.yaml
@@ -0,0 +1,12 @@
class_path: fip.data_modules.numpy_tensor_data_module.NumpyTensorDataModule
init_args:
  data_dir: "./data/er_linear_gaussian_in/total_nodes_5/seed_1/"
  train_batch_size: 400
  test_batch_size: 400
  standardize: false
  with_true_graph: true
  split_data_noise: true
  dod: true
  num_workers: 23
  shuffle: false
  num_interventions: 1
46 changes: 46 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/config/synthetic_data_module.yaml
@@ -0,0 +1,46 @@
class_path: fip.data_modules.synthetic_data_module.SyntheticDataModule
init_args:
  sem_samplers:
    class_path: fip.data_generation.sem_factory.SemSamplerFactory
    init_args:
      node_nums: [20]
      noises: ['gaussian']
      graphs: ['er', 'sf_in', 'sf_out']
      funcs: ['linear', 'rff']
      config_gaussian:
        low: 0.2
        high: 2.0
      config_er:
        edges_per_node: [1,2,3]
      config_sf:
        edges_per_node: [1,2,3]
        attach_power: [1.]
      config_linear:
        weight_low: 1.
        weight_high: 3.
        bias_low: -3.
        bias_high: 3.
      config_rff:
        num_rf: 100
        length_low: 7.
        length_high: 10.
        out_low: 10.
        out_high: 20.
        bias_low: -3.
        bias_high: 3.
  train_batch_size: 4
  test_batch_size: 4
  sample_dataset_size: 400
  standardize: true
  num_samples_used: 400
  num_workers: 23
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 2
  factor_epoch: 32
  num_sems: 0
  shuffle: true
  num_interventions: 2
  num_intervention_samples: 100
  proportion_treatment: 0.
  sample_counterfactuals: false
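
As a rough illustration of what one synthetic training task might look like, the sketch below samples a linear SCM on an Erdős–Rényi DAG with Gaussian noise, using the ranges from this config. It is not the `SemSamplerFactory` implementation. Several readings are assumptions: `edges_per_node` is taken as the expected number of edges per node, `config_gaussian.low/high` as the range of the noise scale, and the weights are given a random sign. The function names are hypothetical.

```python
import random


def sample_linear_scm(num_nodes=20, edges_per_node=2, seed=0):
    """Sample a random linear SCM (assumed reading of the config ranges)."""
    rng = random.Random(seed)
    order = list(range(num_nodes))
    rng.shuffle(order)  # edges only go forward in this order, so the graph is a DAG
    # Edge probability chosen so the expected edge count is edges_per_node * num_nodes.
    p = min(1.0, 2.0 * edges_per_node / (num_nodes - 1))
    parents = {j: [] for j in range(num_nodes)}
    for a in range(num_nodes):
        for b in range(a + 1, num_nodes):
            if rng.random() < p:
                parents[order[b]].append(order[a])
    # weight_low/high = 1, 3 with a random sign (the sign is an assumption).
    weights = {
        j: {i: rng.choice([-1.0, 1.0]) * rng.uniform(1.0, 3.0) for i in parents[j]}
        for j in parents
    }
    bias = {j: rng.uniform(-3.0, 3.0) for j in parents}   # bias_low/high
    scale = {j: rng.uniform(0.2, 2.0) for j in parents}   # config_gaussian low/high
    return order, parents, weights, bias, scale


def draw_sample(scm, rng):
    """Ancestral sampling: visit nodes in topological order, parents first."""
    order, parents, weights, bias, scale = scm
    x = {}
    for j in order:
        mean = bias[j] + sum(weights[j][i] * x[i] for i in parents[j])
        x[j] = mean + scale[j] * rng.gauss(0.0, 1.0)
    return [x[j] for j in range(len(order))]


scm = sample_linear_scm(num_nodes=20, edges_per_node=2, seed=0)
rng = random.Random(1)
dataset = [draw_sample(scm, rng) for _ in range(400)]  # sample_dataset_size: 400
```

The config also lists scale-free graphs (`sf_in`, `sf_out`) and random Fourier feature mechanisms (`rff`), which this sketch does not cover.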
20 changes: 20 additions & 0 deletions
research_experiments/cond_fip/src/cond_fip/entrypoint_test.py
@@ -0,0 +1,20 @@
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI


def main():
    # run=False: only instantiate the trainer, model and datamodule from the
    # config; no subcommand (fit, test, ...) is executed by the CLI itself.
    cli = LightningCLI(
        model_class=pl.LightningModule,
        datamodule_class=pl.LightningDataModule,
        trainer_class=Trainer,
        subclass_mode_data=True,
        subclass_mode_model=True,
        save_config_kwargs={"overwrite": True},
        run=False,
    )
    # Run the test loop explicitly on the instantiated objects.
    cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
    main()