add cond-fip #114

Merged: 1 commit, Dec 16, 2024
5,270 changes: 2,825 additions & 2,445 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "causica"
-version = "0.4.4"
+version = "0.4.5"
 description = ""
 readme = "README.md"
 authors = ["Microsoft Research - Causica"]
61 changes: 61 additions & 0 deletions research_experiments/cond_fip/README.md
@@ -0,0 +1,61 @@
# Zero-Shot Learning of Causal Models (Cond-FiP)
[![Static Badge](https://img.shields.io/badge/paper-CondFiP-brightgreen?style=plastic&label=Paper&labelColor=yellow)](https://arxiv.org/pdf/2410.06128)

This repo implements Cond-FiP, proposed in the paper "Zero-Shot Learning of Causal Models".

Cond-FiP is a transformer-based approach to infer Structural Causal Models (SCMs) in a zero-shot manner. Rather than learning a specific SCM for each dataset, we enable the Fixed-Point Approach (FiP) proposed in [Scetbon et al. (2024)](https://openreview.net/pdf?id=JpzIGzru5F) to infer generative SCMs conditioned on their empirical representations. More specifically, we amortize the learning of a conditional version of FiP, trained on synthetically generated datasets, to infer generative SCMs from observations and causal structures.

Cond-FiP is composed of two models: (1) a dataset Encoder that produces embeddings given the empirical representations of SCMs, and (2) a Decoder that, conditioned on the dataset embedding, infers the generative functional mechanisms of the associated SCM.
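As a rough mental model, inference proceeds in the two stages below. The sketch is purely illustrative: the function names, call signatures, and shapes are assumptions for exposition and do not reflect the actual cond_fip API.

```python
import torch


def infer_scm(encoder, decoder, data: torch.Tensor, graph: torch.Tensor):
    """Hypothetical two-stage Cond-FiP inference.

    data: (n_samples, n_nodes) observations; graph: (n_nodes, n_nodes) adjacency.
    """
    # (1) Dataset Encoder: empirical representation (samples + causal
    #     structure) -> dataset embedding.
    embedding = encoder(data, graph)
    # (2) Decoder: conditioned on the embedding, acts as the inferred SCM's
    #     functional mechanisms (the fixed-point map of FiP).
    return lambda x: decoder(x, embedding)
```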

## Dependencies
We use [Poetry](https://python-poetry.org/) to manage the project dependencies; they are specified in the [pyproject.toml](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```
To install the environment, run `poetry install` in the cond_fip project directory.


## Run experiments
In the [launchers](src/cond_fip/launchers) directory, we provide scripts for training both the Encoder and the Decoder.


### Amortized Learning of the Encoder
To train the Encoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_encoder
```
The model as well as the config file will be saved in `src/cond_fip/outputs`.


### Amortized Learning of Cond-FiP
To train the Decoder on the synthetically generated datasets of [AVICI](https://arxiv.org/abs/2205.12934), run the following command:
```console
python -m cond_fip.launchers.train_cond_fip \
    --run_id <name_of_the_directory_containing_the_trained_encoder_model>
```
The model as well as the config file will be saved in `src/cond_fip/outputs`. This command assumes that an Encoder model has been trained and saved in a directory located at `src/cond_fip/outputs/<name_of_the_directory_containing_the_trained_encoder_model>`.

### Test Cond-FiP on a new Dataset
To test a trained Cond-FiP, we also provide a [launcher file](src/cond_fip/launchers/inference_cond_fip.py) that enables inferring SCMs with Cond-FiP on new datasets.

To use this file, one needs to provide the path to the data in the [config file](src/cond_fip/config/numpy_tensor_data_module.yaml) by replacing the value of `data_dir`.
The data must follow a specific format. Example datasets can be generated by running:

```console
python -m fip.data_generation.avici_data --func_type linear --graph_type er --noise_type gaussian --dist_case in --seed 1 --data_dim 5 --num_interventions 5
```
The data will be stored under `./data`; with the flags above, this corresponds to the default `data_dir` in the data-module config, `./data/er_linear_gaussian_in/total_nodes_5/seed_1/`.

To test a pre-trained Cond-FiP model on a specific dataset, one simply needs to run:
```console
python -m cond_fip.launchers.inference_cond_fip \
    --run_id <name_of_the_directory_containing_the_pre_trained_model> \
    --path_data <path_to_the_data>
```

This command assumes that a pre-trained Cond-FiP model has been saved in a directory located at `src/cond_fip/outputs/<name_of_the_directory_containing_the_pre_trained_model>`, and that the data is available at `<path_to_the_data>`.


6,270 changes: 6,270 additions & 0 deletions research_experiments/cond_fip/poetry.lock

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions research_experiments/cond_fip/pyproject.toml
@@ -0,0 +1,53 @@
[tool.poetry]
name = "cond_fip"
version = "0.1.0"
description = "Zero-Shot Learning of Causal Models"
readme = "README.md"
authors = ["Meyer Scetbon", "Divyat Mahajan"]
packages = [
{ include = "cond_fip", from = "src" },
]
license = "MIT"

[tool.poetry.dependencies]
python = "~3.10"
fip = { path = "../fip"}

[tool.poetry.group.dev.dependencies]
black = {version="^22.6.0", extras=["jupyter"]}
isort = "^5.10.1"
jupyter = "^1.0.0"
jupytext = "^1.13.8"
mypy = "^1.0.0"
pre-commit = "^2.19.0"
pylint = "^2.14.4"
pytest = "^7.1.2"
pytest-cov = "^3.0.0"
seaborn = "^0.12.2"
types-python-dateutil = "^2.8.18"
types-requests = "^2.31.0.10"
ema-pytorch = "^0.6.0"


[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120

[tool.isort]
line_length = 120
profile = "black"
py_version = 310
known_first_party = ["cond_fip"]

# Keep import sorts by code jupytext percent block (https://github.com/PyCQA/isort/issues/1338)
treat_comments_as_code = ["# %%"]

[tool.pytest.ini_options]
addopts = "--durations=200"
junit_family = "xunit1"



@@ -0,0 +1,11 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.cond_fip_inference.CondFiPInference
  init_args:
    enc_dec_model_path: ./src/cond_fip/outputs/amortized_enc_dec_training_2024-09-09_13-51-00/outputs/best_model.ckpt

trainer:
  logger: MLFlowLogger
  accelerator: gpu
  devices: 1
@@ -0,0 +1,63 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.cond_fip_training.CondFiPTraining
  init_args:
    encoder_model_path: ./src/cond_fip/outputs/amortized_encoder_training_2024-07-02_19-09-00/outputs/best_model.ckpt

    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.95
    weight_decay: 1e-10

    use_scheduler: true
    linear_warmup_steps: 1000
    scheduler_steps: 10_000

    d_model: 256
    num_heads: 8
    num_layers: 4
    d_ff: 512
    dropout: 0.1
    dim_key: 64
    num_layers_dataset: 2

    distributed: false
    with_true_target: true
    final_pair_only: true

    with_ema: true
    ema_beta: 0.99
    ema_update_every: 10

trainer:
  max_epochs: 7000
  logger: MLFlowLogger
  accelerator: gpu
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  accumulate_grad_batches: 16
  log_dir: "./src/cond_fip/logging_enc_dec/"
  inference_mode: false
  devices: 1
  num_nodes: 1

early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"

best_checkpoint_callback:
  dirpath: "./src/cond_fip/logging_enc_dec/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1

last_checkpoint_callback:
  save_last: true
  filename: "last_model"
  save_top_k: 0 # only the last checkpoint is saved
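For orientation, the callback blocks above correspond to standard PyTorch Lightning callbacks. A minimal sketch of that mapping, assuming the training task forwards these fields directly (the actual wiring lives in the cond_fip task code):

```python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop when val_loss has not improved by min_delta for `patience` epochs.
early_stopping = EarlyStopping(
    monitor="val_loss", min_delta=0.0001, patience=500, verbose=False, mode="min"
)
# Keep only the single best checkpoint by val_loss, checked every epoch.
best_ckpt = ModelCheckpoint(
    dirpath="./src/cond_fip/logging_enc_dec/",
    filename="best_model",
    save_top_k=1,
    mode="min",
    monitor="val_loss",
    every_n_epochs=1,
)
# Additionally keep the most recent checkpoint (save_top_k=0: last only).
last_ckpt = ModelCheckpoint(save_last=True, filename="last_model", save_top_k=0)
```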
@@ -0,0 +1,59 @@
seed_everything: 2048

model:
  class_path: cond_fip.tasks.encoder_training.EncoderTraining
  init_args:
    learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.95
    weight_decay: 5e-4

    use_scheduler: true
    linear_warmup_steps: 1000
    scheduler_steps: 10_000

    d_model: 256
    num_heads: 8
    num_layers: 4
    d_ff: 512
    dropout: 0.0
    dim_key: 32
    d_hidden_head: 1024

    distributed: false

    with_ema: true
    ema_beta: 0.99
    ema_update_every: 10

trainer:
  max_epochs: 5000
  logger: MLFlowLogger
  accelerator: gpu
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  log_dir: "./src/cond_fip/logging_enc/"
  inference_mode: false
  devices: 1
  num_nodes: 1

early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"

best_checkpoint_callback:
  dirpath: "./src/cond_fip/logging_enc/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1

last_checkpoint_callback:
  save_last: true
  filename: "last_model"
  save_top_k: 0 # only the last checkpoint is saved
12 changes: 12 additions & 0 deletions research_experiments/cond_fip/src/cond_fip/config/numpy_tensor_data_module.yaml
@@ -0,0 +1,12 @@
class_path: fip.data_modules.numpy_tensor_data_module.NumpyTensorDataModule
init_args:
  data_dir: "./data/er_linear_gaussian_in/total_nodes_5/seed_1/"
  train_batch_size: 400
  test_batch_size: 400
  standardize: false
  with_true_graph: true
  split_data_noise: true
  dod: true
  num_workers: 23
  shuffle: false
  num_interventions: 1
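Under Lightning's `class_path`/`init_args` convention, this config is equivalent to constructing the data module directly. A sketch, assuming the `init_args` map one-to-one onto the constructor's keyword arguments (as the convention implies):

```python
from fip.data_modules.numpy_tensor_data_module import NumpyTensorDataModule

# Programmatic equivalent of the YAML config above: each init_args entry
# becomes a constructor keyword argument.
dm = NumpyTensorDataModule(
    data_dir="./data/er_linear_gaussian_in/total_nodes_5/seed_1/",
    train_batch_size=400,
    test_batch_size=400,
    standardize=False,
    with_true_graph=True,
    split_data_noise=True,
    dod=True,
    num_workers=23,
    shuffle=False,
    num_interventions=1,
)
```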
@@ -0,0 +1,46 @@
class_path: fip.data_modules.synthetic_data_module.SyntheticDataModule
init_args:
  sem_samplers:
    class_path: fip.data_generation.sem_factory.SemSamplerFactory
    init_args:
      node_nums: [20]
      noises: ['gaussian']
      graphs: ['er', 'sf_in', 'sf_out']
      funcs: ['linear', 'rff']
      config_gaussian:
        low: 0.2
        high: 2.0
      config_er:
        edges_per_node: [1,2,3]
      config_sf:
        edges_per_node: [1,2,3]
        attach_power: [1.]
      config_linear:
        weight_low: 1.
        weight_high: 3.
        bias_low: -3.
        bias_high: 3.
      config_rff:
        num_rf: 100
        length_low: 7.
        length_high: 10.
        out_low: 10.
        out_high: 20.
        bias_low: -3.
        bias_high: 3.
  train_batch_size: 4
  test_batch_size: 4
  sample_dataset_size: 400
  standardize: true
  num_samples_used: 400
  num_workers: 23
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 2
  factor_epoch: 32
  num_sems: 0
  shuffle: true
  num_interventions: 2
  num_intervention_samples: 100
  proportion_treatment: 0.
  sample_counterfactuals: false
20 changes: 20 additions & 0 deletions research_experiments/cond_fip/src/cond_fip/entrypoint_test.py
@@ -0,0 +1,20 @@
"""Generic test entrypoint: builds a model and datamodule from CLI/YAML
configs via LightningCLI and runs a single test pass."""
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI


def main():
    # subclass_mode_* lets YAML configs pick concrete classes via the
    # class_path/init_args convention; run=False skips LightningCLI's own
    # fit/test dispatch so we can call trainer.test explicitly below.
    cli = LightningCLI(
        model_class=pl.LightningModule,
        datamodule_class=pl.LightningDataModule,
        trainer_class=Trainer,
        subclass_mode_data=True,
        subclass_mode_model=True,
        save_config_kwargs={"overwrite": True},
        run=False,
    )
    # Evaluate the configured model on the datamodule's test split.
    cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
    main()