Skip to content

Commit

Permalink
V0.4 (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisDesdoigts authored May 7, 2023
1 parent 720147c commit 8aec3e0
Show file tree
Hide file tree
Showing 24 changed files with 1,012 additions and 675 deletions.
14 changes: 9 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ jobs:
# ===== Tests =====
- name: base tests
run: pytest --quiet tests/test_base.py
- name: equinox tests
run: pytest --quiet tests/test_equinox.py
- name: equinox tree
- name: tree tests
run: pytest --quiet tests/test_tree.py
- name: equinox optimisation
- name: bayes tests
run: pytest --quiet tests/test_bayes.py
- name: equinox tests
run: pytest --quiet tests/test_eqx.py
- name: optimisation tests
run: pytest --quiet tests/test_optimisation.py
- name: equinox serialisation
- name: jit tests
run: pytest --quiet tests/test_jit.py
- name: serialisation serialisation
run: pytest --quiet tests/test_serialisation.py
9 changes: 0 additions & 9 deletions docs/API/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ value = pytree.get(paths)

```python
pytree = pytree.set(paths, values)
pytree = pytree.set_and_call(paths, values, call_fn)
```

**Arithmetic Methods**
Expand All @@ -26,13 +25,5 @@ pytree = pytree.min(paths, values)
pytree = pytree.max(paths, values)
```

**Functional Methods**

```python
pytree = pytree.apply(paths, fns)
pytree = pytree.apply_args(paths, fns, args)
pytree = pytree.apply_and_call(paths, fns, call_fn)
```

!!! info "Full API"
::: zodiax.base.Base
8 changes: 8 additions & 0 deletions docs/API/bayes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Bayes

The bayesian module is designed to ease the calculation of things like convariance and fisher matrices in differentiable ways. It implements two likelihood functions `poiss_loglike` and `chi2_loglike`. They both take in a pytree and data and return a scalar log likelihood. The `poiss_loglike` function assumes the data is poisson distributed and the `chi2_loglike` function assumes the data is normally distributed. To use these functions the input pytree _must_ have a `.model()` function.

There are also four functions used to calcualte fisher and covariances matrices: `fisher_matrix`, `covariance_matrix`, `self_fisher_matrix`, `self_covariance_matrix`. The `fisher_matrix` and `covariance_matrix` functions take in a pytree, parameters, a log likelihood function and data. They return the fisher and covariance matrices respectively. The `self_fisher_matrix` and `self_covariance_matrix` functions take in a pytree, parameters and a log likelihood function. They return the fisher and covariance matrices respectively, but the data is generated from the model itself.

!!! info "Full API"
::: zodiax.bayes
8 changes: 6 additions & 2 deletions docs/API/equinox.md → docs/API/eqx.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ Submodules in Equinox are also raised into the Zodiax namespace through the `zod

```python
from equinox import nn
from zodiax.equinox import nn
from zodiax.eqx import nn
```

---

There are three methods from Equinox that are overwitten to give them a path based interface. These are `filter_grad`, `filter_value_and_grad`, and `partition`. Their usage can be seen in the 'usage' tutorials.

!!! info "Full API"
::: zodiax.equinox
::: zodiax.eqx
6 changes: 5 additions & 1 deletion docs/API/tree.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Tree

The Tree module provides a module for helpful pytree manipulation functions. It only implements a single function, `get_args(paths)`. It returns a matching pytree with boolean leaves, where the leaves specified by `paths` are `True` and the rest are `False`.
The Tree module provides a module for helpful pytree manipulation functions. It implements two functions, `boolean_filter(pytree, parameters)` and `set_array(pytree, parameters)`.

`boolean_filter(pytree, parameters)` returns a matching pytree with boolean leaves, where the leaves specified by `parameters` are `True` and the rest are `False`.

`set_array(pytree, parameters)` returns a matching pytree with the leaves specified by `parameters` set to the value of the corresponding leaf in `pytree`. This is to ensure they have a shape parameter in order to create dynamic array shapes for the bayesian module.

!!! info "Full API"
::: zodiax.tree
Binary file modified docs/assets/fisher_fit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 10 additions & 25 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ This class simply models a normal distribution with a mean, scale and amplitude,
distribution = normal(10)
```

This is a matter of personal preference, *however* when using Optax if you try to optimise a class that has a `.__call__()` method, you can thrown unhelpful errors. Becuase of this I recommend avoiding `.__call__()` methods and instead using `.model()` method.
This is a matter of personal preference, *however* when using Optax if you try to optimise a class that has a `.__call__()` method, you can thrown unhelpful errors. Becuase of this I recommend avoiding `.__call__()` methods and instead using `.model()` method. Similarly, the `bayes` module of zodiax uses the `.model()` method to evaluate the likelihood of the model, so it is best to use this method to avoid confusion!

Now we construct a class to store and model a set of multiple normals.

Expand Down Expand Up @@ -159,7 +159,7 @@ Since we have constructed the `__getattr__` method, these paths can be simplifie

!!! tip "Path Uniqueness"
Paths must be unique
Paths should not have space in them to work properly with the `__getattrr__`
Paths should not have spaces in them to work properly with the `__getattrr__`

### **Class Methods**

Expand Down Expand Up @@ -382,7 +382,7 @@ Easy! Lets examine the results

### Fisher Inference

The differentiable nature of Zodiax objects also allows us to perform inference on the parameters of our model. The [Laplace approximation](https://en.wikipedia.org/wiki/Laplace%27s_approximation) assumes that the posterior distribution of our model parameters is a gaussian distribution centred on the maximum likelihood estimate of the parameters. Luckily we can use autodiff to calculate the hessian of the log likelihood function and invert it to get the covariance matrix of the posterior distribution!
The differentiable nature of Zodiax objects also allows us to perform inference on the parameters of our model. The [Laplace approximation](https://en.wikipedia.org/wiki/Laplace%27s_approximation) assumes that the posterior distribution of our model parameters is a gaussian distribution centred on the maximum likelihood estimate of the parameters. Luckily we can use autodiff to calculate the hessian of the log likelihood function and invert it to get the covariance matrix of the posterior distribution! Zodiax has some inbuilt functions that can be used to calculate the covariance matrix of a model

??? info "Fisher and Covariance Matrices"
The covariance matrix $\vec{\Sigma}$ describes the covariance between the parameters of a model. Under the Laplace approximation, we can calculate the covariance matrix using autodiff:
Expand All @@ -401,24 +401,9 @@ parameters = ['alpha.mean', 'beta.mean',
'alpha.scale', 'beta.scale',
'alpha.amplitude', 'beta.amplitude']

# Define Likelihod function
def chi2(X, model, data, noise=1):
signal = perturb(X, model).model()
return np.log10(np.square((signal - data) / noise).sum())

# Define Perturbation function
def perturb(X, model):
for parameter, x in zip(parameters, X):
model = model.add(parameter, x)
return model

# Define Covariance function
def calculate_covariance(model, data):
X = np.zeros(len(parameters))
return -np.linalg.inv(jax.hessian(chi2)(X, model, data))

# Calcuate parameter variances
covariance_matrix = calculate_covariance(model, data)
# Get the covariance matrix
covariance_matrix = zdx.covariance_matrix(model, parameters,
zdx.chi2_loglike, data, noise=1/50)
deviations = np.abs(np.diag(covariance_matrix))**0.5
```

Expand Down Expand Up @@ -481,12 +466,12 @@ def sampling_fn(data, model):
# Sample from the posterior distribution
with npy.plate("data", len(data)):
model_sampler = dist.Normal(
model.set_and_call(paths, values, "model")
model.set(paths, values).model().flatten()
)
return npy.sample("Sampler", model_sampler, obs=data)
return npy.sample("Sampler", model_sampler, obs=data.flatten())
```

Numpyo requires a 'sampling' function where you assign priors to your parameters and then sample from the posterior distribution. The syntax for this can be seen above. We then sample the data using a 'plate' and define a likelihood which in this case is a normal. The `set_and_call` function is a Zodiax function that allows us to update the model parameters and then return call some method of that class. This is the function that ultimately allows a simple interface with Numpyro.
Numpyo requires a 'sampling' function where you assign priors to your parameters and then sample from the posterior distribution. The syntax for this can be seen above. We then sample the data using a 'plate' and define a likelihood which in this case is a normal.

We then need to define our sampler which in this case is the No U-Turn Sampler (NUTS). NUTS is a variant of Hamiltonian Monte Carlo (HMC) that is designed to be more efficient and robust, and takes advantage of gradients to allow high dimensional inference.

Expand Down Expand Up @@ -515,4 +500,4 @@ Fantastic now lets have a look at our posterior distributions!
fig.set_size_inches((15, 15))
```

![Numpyro](../assets/hmc_fit.png)
![Numpyro](../assets/hmc_fit.png)
11 changes: 6 additions & 5 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ nav:
- Using Zodiax: docs/usage.md

- API:
- Overview: docs/API/api.md
- Base: docs/API/base.md
- Equinox: docs/API/equinox.md
- Optimisation: docs/API/optimisation.md
- Tree: docs/API/tree.md
- Overview: docs/API/api.md
- Base: docs/API/base.md
- Tree: docs/API/tree.md
- Bayes: docs/API/bayes.md
- Equinox (eqx): docs/API/eqx.md
- Optimisation: docs/API/optimisation.md
- Serialisation: docs/API/serialisation.md

- FAQ & Troubleshooting: docs/faq.md
Expand Down
28 changes: 19 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@
import pytest


@pytest.fixture(scope='class')
def Base_instance(a=1., b=2.):
"""
Construct a Base instance for testing
"""
yield A(a, B(b))


class A(zodiax.base.Base):
"""
Test subclass to test the Base methods
Expand Down Expand Up @@ -43,4 +35,22 @@ def __init__(self, param):
"""
Constructor for the Base testing class
"""
self.param = param
self.param = param


# @pytest.fixture(scope='class')
# def create_Base(a=1., b=2.):
@pytest.fixture
def create_base():
"""
Construct a Base instance for testing
"""
def _create_base(
param : float = 1.,
b : float = 2.,
) -> zodiax.base.Base:
"""
Construct a Base instance for testing
"""
return A(param, B(b))
return _create_base
Loading

0 comments on commit 8aec3e0

Please sign in to comment.