Skip to content

Commit

Permalink
Release 0.3.2 (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
confoundry authored Apr 28, 2023
1 parent 83f2069 commit 8e9cc2d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 75 deletions.
138 changes: 69 additions & 69 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "causica"
version = "0.3.1"
version = "0.3.2"
description = ""
readme = "README.md"
authors = []
Expand All @@ -18,6 +18,7 @@ numpy = "^1.22.4"
pandas = "^1.4.2"
tensorboard = "^2.9.0"
pytorch-lightning = {version = "^1.9.0", extras= ["extra"]}
jsonargparse = "<4.21.0" # 4.21.0 breaks lightning cli
dataclasses-json = "^0.5.7"
types-PyYAML = "^6.0.12.2"
tensordict = "^0.1.0"
Expand Down
6 changes: 3 additions & 3 deletions src/causica/distributions/adjacency/directed_acyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def mode(self) -> torch.Tensor:
We return the mode corresponding to the "default" ordering.
There are 2 possibilities:
p >= 0.5: A lower triangular matrix of ones
p < 0.5: A matrix of zeros
p > 0.5: A lower triangular matrix of ones
p <= 0.5: A matrix of zeros
Returns:
A tensor of shape batch_shape + (num_nodes, num_nodes)
"""
return fill_triangular(self.bern_dist.mode)
return fill_triangular(torch.nan_to_num(self.bern_dist.mode, nan=0.0))

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
5 changes: 4 additions & 1 deletion src/causica/distributions/adjacency/enco.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def mode(self) -> torch.Tensor:
A tensor of shape batch_shape + (num_nodes, num_nodes)
"""
logits = self._get_independent_bernoulli_logits()
return self.base_dist(logits).mode * (1.0 - torch.eye(self.num_nodes, device=logits.device))
# bernoulli mode can be nan for very small logits, favour sparseness and set to 0
return torch.nan_to_num(self.base_dist(logits).mode, nan=0.0) * (
1.0 - torch.eye(self.num_nodes, device=logits.device)
)

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/causica/distributions/adjacency/three_way.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def mode(self) -> torch.Tensor:
Returns:
A tensor of shape batch_shape + (num_nodes, num_nodes)
"""
return _triangular_vec_to_matrix(self.base_dist(self.logits).mode)
# bernoulli mode can be nan for very small logits, favour sparseness and set to 0
return _triangular_vec_to_matrix(torch.nan_to_num(self.base_dist(self.logits).mode, 0.0))

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Expand Down
9 changes: 9 additions & 0 deletions src/causica/distributions/noise/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def noise_to_sample(self, noise: torch.Tensor) -> torch.Tensor:
"""
return ((self.delta_logits + noise) > 0).float()

@property
def mode(self):
"""
Override the default `mode` method to prevent it returning nan's.
We favour sparseness, so if logit == 0, set the mode to be zero.
"""
return (self.logits > 0).to(self.logits)


class BernoulliNoiseModule(NoiseModule[IndependentNoise[BernoulliNoise]]):
"""Represents a BernoulliNoise distribution with learnable logits."""
Expand Down

0 comments on commit 8e9cc2d

Please sign in to comment.