Skip to content

Commit

Permalink
Release 0.3.6 (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
confoundry authored Nov 8, 2023
1 parent e97c5c8 commit f20772b
Show file tree
Hide file tree
Showing 21 changed files with 1,701 additions and 864 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
on:
push:
branches: [ "main" ]
paths-ignore:
- ".github/workflows/*"
- ".devcontainer/*"
- ".gitignore"
- ".pre-commit-config.yaml"
release:
types: [released]
jobs:
pypi-publish:
name: Upload release to PyPI
Expand Down
1,435 changes: 678 additions & 757 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "causica"
version = "0.3.5"
version = "0.3.6"
description = ""
readme = "README.md"
authors = []
Expand Down
19 changes: 18 additions & 1 deletion src/causica/data_generation/samplers/noise_dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from causica.distributions import JointNoiseModule
from causica.distributions.noise import NoiseModule, UnivariateNormalNoiseModule
from causica.distributions.noise.bernoulli import BernoulliNoiseModule
from causica.distributions.noise.categorical import CategoricalNoiseModule


class NoiseModuleSampler(Sampler[NoiseModule]):
Expand Down Expand Up @@ -64,5 +65,21 @@ def __init__(self, base_logits_dist: td.Distribution, dim: int = 1):
def sample(
self,
) -> NoiseModule:
base_logits = self.base_logits_dist.sample().item()
base_logits = self.base_logits_dist.sample()
return BernoulliNoiseModule(dim=self.dim, init_base_logits=base_logits)


class CategoricalNoiseModuleSampler(NoiseModuleSampler):
"""Sample a CategoricalNoiseModule, with num_classes classes. This does not actually sample but returns the noise."""

def __init__(self, base_logits_dist: td.Distribution | None, num_classes: int = 2):
super().__init__()
assert num_classes >= 2
self.num_classes = num_classes
self.base_logits_dist = base_logits_dist

def sample(
self,
) -> NoiseModule:
init_base_logits = self.base_logits_dist.sample() if self.base_logits_dist else None
return CategoricalNoiseModule(num_classes=self.num_classes, init_base_logits=init_base_logits)
16 changes: 16 additions & 0 deletions src/causica/datasets/causica_dataset_format/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from causica.datasets.causica_dataset_format.load import (
CAUSICA_DATASETS_PATH,
CounterfactualWithEffects,
DataEnum,
InterventionWithEffects,
Variable,
VariablesMetadata,
get_group_idxs,
get_group_names,
get_group_variable_names,
get_name_to_idx,
load_data,
tensordict_from_variables_metadata,
tensordict_to_tensor,
)
from causica.datasets.causica_dataset_format.save import save_data, save_dataset
Loading

0 comments on commit f20772b

Please sign in to comment.