From 38903eeeeb9e0f6c884654190ddd9613bccf9898 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:39:06 -0600 Subject: [PATCH] Add convergence test interim samples from image (#36) * function to simplify later steps * refactor, still not done until later pr * test draft * less stuff to run * bug fix * separate out slow and quick tests * add slow marker * flag I alwasy wante * register mark --- .github/workflows/tests.yml | 8 +++- bpd/pipelines/image_ellips.py | 16 +++++++- pyproject.toml | 3 +- scripts/one_galaxy_shear.py | 12 +----- tests/test_convergence.py | 71 ++++++++++++++++++++++++++++++++++- 5 files changed, 93 insertions(+), 17 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cc8ee22..6281d33 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,10 @@ jobs: - name: Run Ruff run: ruff check --output-format=github . - - name: Run Tests + - name: Run fast tests run: | - pytest --durations=0 + pytest -m "not slow" --durations=0 + + - name: Run slow tests + run: | + pytest -m "slow" --durations=0 diff --git a/bpd/pipelines/image_ellips.py b/bpd/pipelines/image_ellips.py index cac4afe..fa62e03 100644 --- a/bpd/pipelines/image_ellips.py +++ b/bpd/pipelines/image_ellips.py @@ -10,7 +10,7 @@ from bpd.chains import run_inference_nuts from bpd.draw import draw_gaussian, draw_gaussian_galsim from bpd.noise import add_noise -from bpd.prior import ellip_mag_prior, sample_ellip_prior +from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation def get_target_galaxy_params_simple( @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple( } +def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]): + true_params = {**galaxy_params} + e1, e2 = true_params.pop("e1"), true_params.pop("e2") + g1, g2 = true_params.pop("g1"), true_params.pop("g2") + + e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) + true_params["e1"] = e1_prime + true_params["e2"] = e2_prime + + return true_params # don't add g1,g2 back as we are not inferring those + + def get_target_images_single( rng_key: PRNGKeyArray, n_samples: int, @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy( max_num_doublings: int = 5, initial_step_size: float = 1e-3, n_warmup_steps: int = 500, - is_mass_matrix_diagonal: bool = False, + is_mass_matrix_diagonal: bool = True, slen: int = 53, fft_size: int = 256, background: float = 1.0, diff --git a/pyproject.toml b/pyproject.toml index 5fc2afe..e9e0692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra" +addopts = "-ra -v --strict-markers" filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/scripts/one_galaxy_shear.py b/scripts/one_galaxy_shear.py index 6e147b3..74b5531 100755 --- a/scripts/one_galaxy_shear.py +++ b/scripts/one_galaxy_shear.py @@ -11,10 +11,10 @@ from bpd.pipelines.image_ellips import ( get_target_galaxy_params_simple, get_target_images_single, + get_true_params_from_galaxy_params, pipeline_image_interim_samples, ) from bpd.pipelines.shear_inference import pipeline_shear_inference -from bpd.prior import scalar_shear_transformation init_fnc = init_with_truth @@ -53,19 +53,11 @@ def main( nkey, n_samples=n_gals, single_galaxy_params=galaxy_params, - psf_hlr=psf_hlr, background=background, slen=slen, - pixel_scale=pixel_scale, ) - true_params = {**galaxy_params} - e1, e2 = true_params.pop("e1"), true_params.pop("e2") - g1, g2 = true_params.pop("g1"), true_params.pop("g2") - - e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) - true_params["e1"] = e1_prime - true_params["e2"] = e2_prime + true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines pipe1 = partial( diff --git a/tests/test_convergence.py b/tests/test_convergence.py index f2c4c25..b3d8d47 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -9,6 +9,13 @@ from jax import random, vmap from bpd.chains import run_inference_nuts +from bpd.initialization import init_with_truth +from bpd.pipelines.image_ellips import ( + get_target_galaxy_params_simple, + get_target_images_single, + get_true_params_from_galaxy_params, + pipeline_image_interim_samples_one_galaxy, +) from bpd.pipelines.shear_inference import pipeline_shear_inference from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples @@ -16,7 +23,7 @@ @pytest.mark.parametrize("seed", [1234, 4567]) -def test_interim_ellipticity_posterior_convergence(seed): +def test_interim_toy_convergence(seed): """Check efficiency and convergence of chains for 100 galaxies.""" g1, g2 = 0.02, 0.0 sigma_m = 1e-4 @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed): @pytest.mark.parametrize("seed", [1234, 4567]) -def test_shear_posterior_convergence(seed): +def test_toy_shear_convergence(seed): g1, g2 = 0.02, 0.0 sigma_m = 1e-4 sigma_e = 1e-3 @@ -124,3 +131,63 @@ def test_shear_posterior_convergence(seed): assert ess > 0.5 * 4000 assert jnp.abs(rhat - 1) < 0.01 + + +@pytest.mark.slow +@pytest.mark.parametrize("seed", [1234, 4567]) +def test_low_noise_single_galaxy_interim_samples(seed): + lf = 6.0 + hlr = 1.0 + g1, g2 = 0.02, 0.0 + sigma_e = 1e-3 + sigma_e_int = 3e-2 + n_samples = 500 + background = 1.0 + slen = 53 + fft_size = 256 + init_fnc = init_with_truth + + rng_key = random.key(seed) + pkey, nkey, gkey = random.split(rng_key, 3) + + galaxy_params = get_target_galaxy_params_simple( + pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e + ) + + draw_params = {**galaxy_params} + draw_params["f"] = 10 ** draw_params.pop("lf") + target_image = get_target_images_single( + nkey, + n_samples=1, + single_galaxy_params=draw_params, + background=background, + slen=slen, + )[0] + true_params = get_true_params_from_galaxy_params(galaxy_params) + + pipe1 = partial( + pipeline_image_interim_samples_one_galaxy, + initialization_fnc=init_fnc, + sigma_e_int=sigma_e_int, + n_samples=n_samples, + slen=slen, + fft_size=fft_size, + n_warmup_steps=300, + ) + vpipe1 = vmap(jjit(pipe1), (0, 0, None)) + + # chain initialization + # one galaxy, test convergence, so 4 random seeds + keys = random.split(gkey, 4) + init_positions = vmap(init_fnc, (0, None))(keys, true_params) + + samples = vpipe1(keys, init_positions, target_image) + + # check each component + for _, v in samples.items(): + assert v.shape == (4, n_samples) + ess = effective_sample_size(v) + rhat = potential_scale_reduction(v) + + assert ess > 0.5 * n_samples + assert jnp.abs(rhat - 1) < 0.01