diff --git a/python/celerite2/pymc/distribution.py b/python/celerite2/pymc/distribution.py index 1583772..629925b 100644 --- a/python/celerite2/pymc/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -31,8 +31,7 @@ def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None): class CeleriteNormalRV(RandomVariable): name = "celerite_normal" - ndim_supp = 1 - ndims_params = [1, 0, 1, 1, 2, 2, 1] + signature = "(i_mean),(),(i_t),(i_c),(i_U1,i_U2),(i_W1,i_W2),(i_d)->(i)" dtype = "floatX" _print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}") @@ -46,12 +45,15 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): @classmethod def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size): + # Hardcoded because no longer set as class attribute + # Ref: https://github.com/pymc-devs/pytensor/issues/866 + ndims_params = [1, 0, 1, 1, 2, 2, 1] if any( x.ndim > n - for n, x in zip(cls.ndims_params, [mean, norm, t, c, U, W, d]) + for n, x in zip(ndims_params, [mean, norm, t, c, U, W, d]) ): mean, norm, t, c, U, W, d = broadcast_params( - [mean, norm, t, c, U, W, d], cls.ndims_params + [mean, norm, t, c, U, W, d], ndims_params ) size = tuple(size or ()) @@ -111,12 +113,12 @@ def dist(cls, mean, norm, t, c, U, W, d, **kwargs): mean = pt.broadcast_arrays(mean, t)[0] return super().dist([mean, norm, t, c, U, W, d], **kwargs) - def moment(rv, size, mean, *args): - moment = mean + def support_point(rv, size, mean, *args): + support_point = mean if not rv_size_is_none(size): - moment_size = pt.concatenate([size, [mean.shape[-1]]]) - moment = pt.full(moment_size, mean) - return moment + support_size = pt.concatenate([size, [mean.shape[-1]]]) + support_point = pt.full(support_size, mean) + return support_point def logp(value, mean, norm, t, c, U, W, d): ok = pt.all(pt.gt(d, 0.0)) diff --git a/python/test/pymc/test_pymc_distribution.py b/python/test/pymc/test_pymc_distribution.py new file mode 100644 index 0000000..24a70f3 --- /dev/null +++ b/python/test/pymc/test_pymc_distribution.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest + +pytest.importorskip("celerite2.pymc") + +try: + from pymc.testing import assert_support_point_is_expected + + from celerite2.pymc import GaussianProcess, terms + from celerite2.pymc.distribution import CeleriteNormalRV +except (ImportError, ModuleNotFoundError): + pass + + +def test_celerite_normal_rv(): + # Test that ndims_params and ndim_supp have the expected value + # now that they are created from signature + celerite_normal = CeleriteNormalRV() + assert celerite_normal.ndim_supp == 1 + assert tuple(celerite_normal.ndims_params) == (1, 0, 1, 1, 2, 2, 1) + + +@pytest.mark.parametrize( + "t, mean, size, expected", + [ + (np.arange(5, dtype=float), 0.0, None, np.full(5, 0.0)), + ( + np.arange(5, dtype=float), + np.arange(5, dtype=float), + None, + np.arange(5, dtype=float), + ), + ], +) +def test_celerite_normal_support_point(t, mean, size, expected): + # Test that support point has the expected shape and value + pm = pytest.importorskip("pymc") + + with pm.Model() as model: + term = terms.SHOTerm(S0=1.0, w0=0.5, Q=3.0) + gp = GaussianProcess(term, t=t, mean=mean) + # NOTE: Name must be "x" for assert function to work + gp.marginal("x", size=size) + assert_support_point_is_expected(model, expected)