Skip to content

Commit

Permalink
Fix PyMC and PyTensor deprecation warnings (#130)
Browse files Browse the repository at this point in the history
* Add tests for CeleriteNormalRV random variable and support point

* Update `CeleriteNormalRV` to use signature

* Replace `moment` by `support_point` in `CeleriteNormalRV`

* Add coments to test

Needed a commit to trigger Github Actions
  • Loading branch information
vandalt authored Jun 29, 2024
1 parent 12a4a8f commit a01d20d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/celerite2/pymc/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None):

class CeleriteNormalRV(RandomVariable):
name = "celerite_normal"
ndim_supp = 1
ndims_params = [1, 0, 1, 1, 2, 2, 1]
signature = "(i_mean),(),(i_t),(i_c),(i_U1,i_U2),(i_W1,i_W2),(i_d)->(i)"
dtype = "floatX"
_print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}")

Expand All @@ -46,12 +45,15 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):

@classmethod
def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size):
# Hardcoded because no longer set as class attribute
# Ref: https://github.com/pymc-devs/pytensor/issues/866
ndims_params = [1, 0, 1, 1, 2, 2, 1]
if any(
x.ndim > n
for n, x in zip(cls.ndims_params, [mean, norm, t, c, U, W, d])
for n, x in zip(ndims_params, [mean, norm, t, c, U, W, d])
):
mean, norm, t, c, U, W, d = broadcast_params(
[mean, norm, t, c, U, W, d], cls.ndims_params
[mean, norm, t, c, U, W, d], ndims_params
)
size = tuple(size or ())

Expand Down Expand Up @@ -111,12 +113,12 @@ def dist(cls, mean, norm, t, c, U, W, d, **kwargs):
mean = pt.broadcast_arrays(mean, t)[0]
return super().dist([mean, norm, t, c, U, W, d], **kwargs)

def moment(rv, size, mean, *args):
moment = mean
def support_point(rv, size, mean, *args):
support_point = mean
if not rv_size_is_none(size):
moment_size = pt.concatenate([size, [mean.shape[-1]]])
moment = pt.full(moment_size, mean)
return moment
support_size = pt.concatenate([size, [mean.shape[-1]]])
support_point = pt.full(support_size, mean)
return support_point

def logp(value, mean, norm, t, c, U, W, d):
ok = pt.all(pt.gt(d, 0.0))
Expand Down
44 changes: 44 additions & 0 deletions python/test/pymc/test_pymc_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest

pytest.importorskip("celerite2.pymc")

try:
from pymc.testing import assert_support_point_is_expected

from celerite2.pymc import GaussianProcess, terms
from celerite2.pymc.distribution import CeleriteNormalRV
except (ImportError, ModuleNotFoundError):
pass


def test_celerite_normal_rv():
# Test that ndims_params and ndim_supp have the expected value
# now that they are created from signature
celerite_normal = CeleriteNormalRV()
assert celerite_normal.ndim_supp == 1
assert tuple(celerite_normal.ndims_params) == (1, 0, 1, 1, 2, 2, 1)


@pytest.mark.parametrize(
"t, mean, size, expected",
[
(np.arange(5, dtype=float), 0.0, None, np.full(5, 0.0)),
(
np.arange(5, dtype=float),
np.arange(5, dtype=float),
None,
np.arange(5, dtype=float),
),
],
)
def test_celerite_normal_support_point(t, mean, size, expected):
# Test that support point has the expected shape and value
pm = pytest.importorskip("pymc")

with pm.Model() as model:
term = terms.SHOTerm(S0=1.0, w0=0.5, Q=3.0)
gp = GaussianProcess(term, t=t, mean=mean)
# NOTE: Name must be "x" for assert function to work
gp.marginal("x", size=size)
assert_support_point_is_expected(model, expected)

0 comments on commit a01d20d

Please sign in to comment.