From 363ea57cf47e5b054ffeb93249f67f2e59857a30 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:21 -0800 Subject: [PATCH 01/12] function to simplify later steps --- bpd/pipelines/image_ellips.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/bpd/pipelines/image_ellips.py b/bpd/pipelines/image_ellips.py index cac4afe..fa62e03 100644 --- a/bpd/pipelines/image_ellips.py +++ b/bpd/pipelines/image_ellips.py @@ -10,7 +10,7 @@ from bpd.chains import run_inference_nuts from bpd.draw import draw_gaussian, draw_gaussian_galsim from bpd.noise import add_noise -from bpd.prior import ellip_mag_prior, sample_ellip_prior +from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation def get_target_galaxy_params_simple( @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple( } +def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]): + true_params = {**galaxy_params} + e1, e2 = true_params.pop("e1"), true_params.pop("e2") + g1, g2 = true_params.pop("g1"), true_params.pop("g2") + + e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) + true_params["e1"] = e1_prime + true_params["e2"] = e2_prime + + return true_params # don't add g1,g2 back as we are not inferring those + + def get_target_images_single( rng_key: PRNGKeyArray, n_samples: int, @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy( max_num_doublings: int = 5, initial_step_size: float = 1e-3, n_warmup_steps: int = 500, - is_mass_matrix_diagonal: bool = False, + is_mass_matrix_diagonal: bool = True, slen: int = 53, fft_size: int = 256, background: float = 1.0, From a51d9b3a9cc5394a513943aeba5697d8b9a08ebc Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:32 -0800 Subject: [PATCH 02/12] refactor, still not done until later pr --- scripts/one_galaxy_shear.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/scripts/one_galaxy_shear.py b/scripts/one_galaxy_shear.py index 6e147b3..74b5531 100755 --- a/scripts/one_galaxy_shear.py +++ b/scripts/one_galaxy_shear.py @@ -11,10 +11,10 @@ from bpd.pipelines.image_ellips import ( get_target_galaxy_params_simple, get_target_images_single, + get_true_params_from_galaxy_params, pipeline_image_interim_samples, ) from bpd.pipelines.shear_inference import pipeline_shear_inference -from bpd.prior import scalar_shear_transformation init_fnc = init_with_truth @@ -53,19 +53,11 @@ def main( nkey, n_samples=n_gals, single_galaxy_params=galaxy_params, - psf_hlr=psf_hlr, background=background, slen=slen, - pixel_scale=pixel_scale, ) - true_params = {**galaxy_params} - e1, e2 = true_params.pop("e1"), true_params.pop("e2") - g1, g2 = true_params.pop("g1"), true_params.pop("g2") - - e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2)) - true_params["e1"] = e1_prime - true_params["e2"] = e2_prime + true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines pipe1 = partial( From 0442294027abef13a26063de71b9add520c1c234 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:36:37 -0800 Subject: [PATCH 03/12] test draft --- tests/test_convergence.py | 69 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index f2c4c25..f70b39c 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -9,6 +9,13 @@ from jax import random, vmap from bpd.chains import run_inference_nuts +from bpd.initialization import init_with_truth +from bpd.pipelines.image_ellips import ( + get_target_galaxy_params_simple, + get_target_images_single, + get_true_params_from_galaxy_params, + pipeline_image_interim_samples_one_galaxy, +) from bpd.pipelines.shear_inference import pipeline_shear_inference from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples @@ -16,7 +23,7 @@ @pytest.mark.parametrize("seed", [1234, 4567]) -def test_interim_ellipticity_posterior_convergence(seed): +def test_interim_toy_convergence(seed): """Check efficiency and convergence of chains for 100 galaxies.""" g1, g2 = 0.02, 0.0 sigma_m = 1e-4 @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed): @pytest.mark.parametrize("seed", [1234, 4567]) -def test_shear_posterior_convergence(seed): +def test_toy_shear_convergence(seed): g1, g2 = 0.02, 0.0 sigma_m = 1e-4 sigma_e = 1e-3 @@ -124,3 +131,61 @@ def test_shear_posterior_convergence(seed): assert ess > 0.5 * 4000 assert jnp.abs(rhat - 1) < 0.01 + + +@pytest.mark.parametrize("seed", [1234, 4567, 1111, 2222]) +def test_low_noise_single_galaxy_interim_samples(seed): + lf = 6.0 + hlr = 1.0 + g1, g2 = 0.02, 0.0 + sigma_e = 1e-3 + sigma_e_int = 3e-2 + n_samples = 1000 + background = 1.0 + slen = 53 + fft_size = 256 + init_fnc = init_with_truth + + rng_key = random.key(seed) + pkey, nkey, gkey = random.split(rng_key, 3) + + galaxy_params = get_target_galaxy_params_simple( + pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e + ) + + draw_params = {**galaxy_params} + draw_params["f"] = 10 ** draw_params.pop("lf") + target_image = get_target_images_single( + nkey, + n_samples=1, + single_galaxy_params=draw_params, + background=background, + slen=slen, + )[0] + true_params = get_true_params_from_galaxy_params(galaxy_params) + + pipe1 = partial( + pipeline_image_interim_samples_one_galaxy, + initialization_fnc=init_fnc, + sigma_e_int=sigma_e_int, + n_samples=n_samples, + slen=slen, + fft_size=fft_size, + ) + vpipe1 = vmap(jjit(pipe1), (0, 0, None)) + + # chain initialization + # one galaxy, test convergence, so 4 random seeds + keys = random.split(gkey, 4) + init_positions = vmap(init_fnc, (0, None))(keys, true_params) + + samples = vpipe1(keys, init_positions, target_image) + + # check each component + for _, v in samples.item(): + assert v.shape == (4, 1000) + ess = effective_sample_size(v) + rhat = potential_scale_reduction(v) + + assert ess > 0.5 * 4000 + assert jnp.abs(rhat - 1) < 0.01 From 601be04489b1f7744c3a781f6cc72a10f095e362 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 13:46:49 -0800 Subject: [PATCH 04/12] less stuff to run --- tests/test_convergence.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index f70b39c..b7b3690 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -133,14 +133,14 @@ def test_toy_shear_convergence(seed): assert jnp.abs(rhat - 1) < 0.01 -@pytest.mark.parametrize("seed", [1234, 4567, 1111, 2222]) +@pytest.mark.parametrize("seed", [1234, 4567]) def test_low_noise_single_galaxy_interim_samples(seed): lf = 6.0 hlr = 1.0 g1, g2 = 0.02, 0.0 sigma_e = 1e-3 sigma_e_int = 3e-2 - n_samples = 1000 + n_samples = 500 background = 1.0 slen = 53 fft_size = 256 @@ -171,6 +171,7 @@ def test_low_noise_single_galaxy_interim_samples(seed): n_samples=n_samples, slen=slen, fft_size=fft_size, + n_warmup_steps=300, ) vpipe1 = vmap(jjit(pipe1), (0, 0, None)) @@ -183,9 +184,9 @@ def test_low_noise_single_galaxy_interim_samples(seed): # check each component for _, v in samples.item(): - assert v.shape == (4, 1000) + assert v.shape == (4, n_samples) ess = effective_sample_size(v) rhat = potential_scale_reduction(v) - assert ess > 0.5 * 4000 + assert ess > 0.5 * n_samples assert jnp.abs(rhat - 1) < 0.01 From f4fb1e9c1131c56253fea04b214c0c038abcd184 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Mon, 4 Nov 2024 14:36:01 -0800 Subject: [PATCH 05/12] bug fix --- tests/test_convergence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index b7b3690..00f87a2 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -183,7 +183,7 @@ def test_low_noise_single_galaxy_interim_samples(seed): samples = vpipe1(keys, init_positions, target_image) # check each component - for _, v in samples.item(): + for _, v in samples.items(): assert v.shape == (4, n_samples) ess = effective_sample_size(v) rhat = potential_scale_reduction(v) From b688e8759a9179c5f1e6fd1f1a9864a2b05155e5 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:56:16 -0800 Subject: [PATCH 06/12] separate out slow and quick tests --- .github/workflows/tests.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cc8ee22..6281d33 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,10 @@ jobs: - name: Run Ruff run: ruff check --output-format=github . - - name: Run Tests + - name: Run fast tests run: | - pytest --durations=0 + pytest -m "not slow" --durations=0 + + - name: Run slow tests + run: | + pytest -m "slow" --durations=0 From 80d1d8ed34128d91c4d2c50d75f681585fdc01f3 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:57:11 -0800 Subject: [PATCH 07/12] add slow marker --- tests/test_convergence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 00f87a2..b3d8d47 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -133,6 +133,7 @@ def test_toy_shear_convergence(seed): assert jnp.abs(rhat - 1) < 0.01 +@pytest.mark.slow @pytest.mark.parametrize("seed", [1234, 4567]) def test_low_noise_single_galaxy_interim_samples(seed): lf = 6.0 From e67657a0d78fdd1b741fddc8aa19269a07dec1cd Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 11:57:16 -0800 Subject: [PATCH 08/12] flag I alwasy wante --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5fc2afe..d569a09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,5 +86,5 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra" +addopts = "-ra -v" filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] From c81cde4949ead8d26ea4ab0cad62aca1c3e9fe77 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:05:21 -0800 Subject: [PATCH 09/12] register mark --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d569a09..e9e0692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra -v" +addopts = "-ra -v --strict-markers" filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"] +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] From 13fd9797cd80f2d06a4317f3aa2ac1eadb4ba2bc Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:17:02 -0800 Subject: [PATCH 10/12] will split up into two processes, as first one is quite expensive --- scripts/get_shear_from_interim_samples.py | 64 +++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100755 scripts/get_shear_from_interim_samples.py diff --git a/scripts/get_shear_from_interim_samples.py b/scripts/get_shear_from_interim_samples.py new file mode 100755 index 0000000..6005406 --- /dev/null +++ b/scripts/get_shear_from_interim_samples.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +"""This file creates toy samples of ellipticities and saves them to .hdf5 file.""" + +from pathlib import Path + +import jax +import jax.numpy as jnp +import typer + +from bpd import DATA_DIR +from bpd.io import load_dataset +from bpd.pipelines.shear_inference import pipeline_shear_inference + + +def _extract_seed(fpath: str) -> int: + name = Path(fpath).name + first = name.find("_") + second = name.find("_", first + 1) + third = name.find(".") + return int(name[second + 1 : third]) + + +def main( + seed: int, + tag: str, + interim_samples_fname: str, + sigma_e_int: float = 3e-2, + initial_step_size: float = 1e-3, + n_samples: int = 3000, + trim: int = 1, + overwrite: bool = False, +): + # directory structure + dirpath = DATA_DIR / "cache_chains" / tag + assert dirpath.exists() + interim_samples_fpath = DATA_DIR / "cache_chains" / tag / interim_samples_fname + assert interim_samples_fpath.exists(), "ellipticity samples file does not exist" + old_seed = _extract_seed(interim_samples_fpath) + fpath = DATA_DIR / "cache_chains" / tag / f"g_samples_{old_seed}_{seed}.npy" + + if fpath.exists() and not overwrite: + raise IOError("overwriting...") + + samples_dataset = load_dataset(interim_samples_fpath) + e_post = samples_dataset["e_post"][:, ::trim, :] + true_g = samples_dataset["true_g"] + sigma_e = samples_dataset["sigma_e"] + + rng_key = jax.random.key(seed) + g_samples = pipeline_shear_inference( + rng_key, + e_post, + true_g=true_g, + sigma_e=sigma_e, + sigma_e_int=sigma_e_int, + initial_step_size=initial_step_size, + n_samples=n_samples, + ) + + jnp.save(fpath, g_samples) + + +if __name__ == "__main__": + typer.run(main) From 6896899e0733333dba8625b687a6c28002255550 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:40:45 -0800 Subject: [PATCH 11/12] various test fixes --- tests/test_convergence.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_convergence.py b/tests/test_convergence.py index b3d8d47..141dc58 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -156,13 +156,15 @@ def test_low_noise_single_galaxy_interim_samples(seed): draw_params = {**galaxy_params} draw_params["f"] = 10 ** draw_params.pop("lf") - target_image = get_target_images_single( + target_image, _ = get_target_images_single( nkey, n_samples=1, single_galaxy_params=draw_params, background=background, slen=slen, - )[0] + ) + assert target_image.shape == (1, slen, slen) + true_params = get_true_params_from_galaxy_params(galaxy_params) pipe1 = partial( @@ -178,10 +180,13 @@ def test_low_noise_single_galaxy_interim_samples(seed): # chain initialization # one galaxy, test convergence, so 4 random seeds - keys = random.split(gkey, 4) - init_positions = vmap(init_fnc, (0, None))(keys, true_params) + gkey1, gkey2 = random.split(gkey, 2) + keys1 = random.split(gkey1, 4) + keys2 = random.split(gkey2, 4) + + init_positions = vmap(init_fnc, (0, None))(keys1, true_params) - samples = vpipe1(keys, init_positions, target_image) + samples = vpipe1(keys2, init_positions, target_image) # check each component for _, v in samples.items(): From 5311d8a718b58939beb98627e867a1931c66c0b1 Mon Sep 17 00:00:00 2001 From: ismael2395 Date: Tue, 5 Nov 2024 12:43:32 -0800 Subject: [PATCH 12/12] continue draft --- ...py => one_galaxy_image_interim_samples.py} | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) rename scripts/{one_galaxy_shear.py => one_galaxy_image_interim_samples.py} (71%) diff --git a/scripts/one_galaxy_shear.py b/scripts/one_galaxy_image_interim_samples.py similarity index 71% rename from scripts/one_galaxy_shear.py rename to scripts/one_galaxy_image_interim_samples.py index 74b5531..3fc542a 100755 --- a/scripts/one_galaxy_shear.py +++ b/scripts/one_galaxy_image_interim_samples.py @@ -12,9 +12,8 @@ get_target_galaxy_params_simple, get_target_images_single, get_true_params_from_galaxy_params, - pipeline_image_interim_samples, + pipeline_image_interim_samples_one_galaxy, ) -from bpd.pipelines.shear_inference import pipeline_shear_inference init_fnc = init_with_truth @@ -22,6 +21,9 @@ def main( tag: str, seed: int, + n_gals: int = 100, # technically, in this file it means 'noise realizations' + n_samples_per_gal: int = 100, + n_vec: int = 50, # how many galaxies to process simultaneously in 1 GPU core g1: float = 0.02, g2: float = 0.0, lf: float = 6.0, @@ -30,13 +32,11 @@ def main( slen: int = 53, fft_size: int = 256, background: float = 1.0, - n_gals: int = 1000, # technically, here it means 'noise realizations' - n_samples_shear: int = 3000, - n_samples_per_gal: int = 100, + initial_step_size: float = 1e-3, trim: int = 1, ): rng_key = random.key(seed) - pkey, nkey, gkey, skey = random.split(rng_key, 3) + pkey, nkey, gkey = random.split(rng_key, 3) # directory structure dirpath = DATA_DIR / "cache_chains" / tag @@ -44,51 +44,52 @@ def main( if not dirpath.exists(): dirpath.mkdir(exist_ok=True) + fpath = dirpath / f"e_post_{seed}.npy" + # get images - galaxy_params = get_target_galaxy_params_simple( - pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=shape_noise + galaxy_params = get_target_galaxy_params_simple( # default hlr, x, y + pkey, lf=lf, g1=g1, g2=g2, shape_noise=shape_noise ) - target_images = get_target_images_single( + draw_params = {**galaxy_params} + draw_params["f"] = 10 ** draw_params.pop("lf") + target_images, _ = get_target_images_single( nkey, n_samples=n_gals, - single_galaxy_params=galaxy_params, + single_galaxy_params=draw_params, background=background, slen=slen, ) + assert target_images.shape == (n_gals, slen, slen) true_params = get_true_params_from_galaxy_params(galaxy_params) # prepare pipelines pipe1 = partial( - pipeline_image_interim_samples, + pipeline_image_interim_samples_one_galaxy, initialization_fnc=init_fnc, - n_samples=k, - max_num_doublings=5, - initial_step_size=1e-3, - n_warmup_steps=500, - is_mass_matrix_diagonal=True, - background=background, + sigma_e_int=sigma_e_int, + n_samples=n_samples_per_gal, + initial_step_size=initial_step_size, slen=slen, - pixel_scale=pixel_scale, fft_size=fft_size, + background=background, ) vpipe1 = vmap(jjit(pipe1), (0, None, 0)) - pipe2 = partial( - pipeline_shear_inference, - true_g=jnp.array([g1, g2]), - sigma_e=shape_noise, - sigma_e_int=sigma_e_int, - n_samples=n_samples_shear, - ) - vpipe2 = vmap(pipe2, in_axes=(0, 0)) - + # initialization gkeys = random.split(gkey, n_gals) + init_positions = vmap(init_fnc, (0, None))(keys, true_params) + + galaxy_samples = vpipe1(gkeys, true_params, target_images) + + e_post = jnp.stack([galaxy_samples["e1"], galaxy_samples["e2"]], axis=-1) + jnp.save( + g_samples = vpipe2(skey, e_post)