Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

demonstrate shear inference on single galaxy (multiple noise realizations) #37

Closed
wants to merge 12 commits into from
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
- name: Run Ruff
run: ruff check --output-format=github .

- name: Run Tests
- name: Run fast tests
run: |
pytest --durations=0
pytest -m "not slow" --durations=0

- name: Run slow tests
run: |
pytest -m "slow" --durations=0
16 changes: 14 additions & 2 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior
from bpd.prior import ellip_mag_prior, sample_ellip_prior, scalar_shear_transformation


def get_target_galaxy_params_simple(
Expand All @@ -37,6 +37,18 @@ def get_target_galaxy_params_simple(
}


def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]):
    """Convert pre-shear galaxy parameters into the 'true' inferred parameters.

    Applies the shear ``(g1, g2)`` to the intrinsic ellipticity ``(e1, e2)``
    and drops the shear entries from the result, since shear is not part of
    the per-galaxy inference.
    """
    excluded = ("e1", "e2", "g1", "g2")
    params = {k: v for k, v in galaxy_params.items() if k not in excluded}

    ellip = (galaxy_params["e1"], galaxy_params["e2"])
    shear = (galaxy_params["g1"], galaxy_params["g2"])
    sheared_e1, sheared_e2 = scalar_shear_transformation(ellip, shear)

    params["e1"] = sheared_e1
    params["e2"] = sheared_e2
    return params  # don't add g1,g2 back as we are not inferring those


def get_target_images_single(
rng_key: PRNGKeyArray,
n_samples: int,
Expand Down Expand Up @@ -112,7 +124,7 @@ def pipeline_image_interim_samples_one_galaxy(
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = False,
is_mass_matrix_diagonal: bool = True,
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,6 @@ exclude = ["*.ipynb", "scripts/one_galaxy_shear.py", "scripts/benchmarks/*.py"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra"
addopts = "-ra -v --strict-markers"
filterwarnings = ["ignore::DeprecationWarning:tensorflow.*"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
64 changes: 64 additions & 0 deletions scripts/get_shear_from_interim_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
"""This file creates toy samples of ellipticities and saves them to .hdf5 file."""

from pathlib import Path

import jax
import jax.numpy as jnp
import typer

from bpd import DATA_DIR
from bpd.io import load_dataset
from bpd.pipelines.shear_inference import pipeline_shear_inference


def _extract_seed(fpath: str) -> int:
name = Path(fpath).name
first = name.find("_")
second = name.find("_", first + 1)
third = name.find(".")
return int(name[second + 1 : third])


def main(
    seed: int,
    tag: str,
    interim_samples_fname: str,
    sigma_e_int: float = 3e-2,
    initial_step_size: float = 1e-3,
    n_samples: int = 3000,
    trim: int = 1,
    overwrite: bool = False,
):
    """Run shear inference on previously saved interim ellipticity samples.

    Loads ``e_post`` samples from ``cache_chains/<tag>/<interim_samples_fname>``
    and writes the shear posterior samples to
    ``cache_chains/<tag>/g_samples_<old_seed>_<seed>.npy``.

    Raises:
        FileExistsError: if the output file already exists and ``overwrite``
            is False.
    """
    # directory structure
    dirpath = DATA_DIR / "cache_chains" / tag
    assert dirpath.exists()
    interim_samples_fpath = dirpath / interim_samples_fname
    assert interim_samples_fpath.exists(), "ellipticity samples file does not exist"
    old_seed = _extract_seed(interim_samples_fpath)
    fpath = dirpath / f"g_samples_{old_seed}_{seed}.npy"

    # Refuse to clobber an existing output unless explicitly requested.
    # (Original raised IOError("overwriting...") — misleading, since this path
    # triggers exactly when we are NOT overwriting. FileExistsError is a
    # subclass of OSError == IOError, so existing handlers still catch it.)
    if fpath.exists() and not overwrite:
        raise FileExistsError(f"{fpath} already exists; pass --overwrite to replace it")

    samples_dataset = load_dataset(interim_samples_fpath)
    # thin chains: keep every `trim`-th interim sample along the chain axis
    e_post = samples_dataset["e_post"][:, ::trim, :]
    true_g = samples_dataset["true_g"]
    sigma_e = samples_dataset["sigma_e"]

    rng_key = jax.random.key(seed)
    g_samples = pipeline_shear_inference(
        rng_key,
        e_post,
        true_g=true_g,
        sigma_e=sigma_e,
        sigma_e_int=sigma_e_int,
        initial_step_size=initial_step_size,
        n_samples=n_samples,
    )

    jnp.save(fpath, g_samples)


if __name__ == "__main__":
typer.run(main)
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
pipeline_image_interim_samples,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.prior import scalar_shear_transformation

init_fnc = init_with_truth


def main(
tag: str,
seed: int,
n_gals: int = 100, # technically, in this file it means 'noise realizations'
n_samples_per_gal: int = 100,
n_vec: int = 50, # how many galaxies to process simultaneously in 1 GPU core
g1: float = 0.02,
g2: float = 0.0,
lf: float = 6.0,
Expand All @@ -30,73 +32,64 @@ def main(
slen: int = 53,
fft_size: int = 256,
background: float = 1.0,
n_gals: int = 1000, # technically, here it means 'noise realizations'
n_samples_shear: int = 3000,
n_samples_per_gal: int = 100,
initial_step_size: float = 1e-3,
trim: int = 1,
):
rng_key = random.key(seed)
pkey, nkey, gkey, skey = random.split(rng_key, 3)
pkey, nkey, gkey = random.split(rng_key, 3)

# directory structure
dirpath = DATA_DIR / "cache_chains" / tag

if not dirpath.exists():
dirpath.mkdir(exist_ok=True)

fpath = dirpath / f"e_post_{seed}.npy"

# get images
galaxy_params = get_target_galaxy_params_simple(
pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=shape_noise
galaxy_params = get_target_galaxy_params_simple( # default hlr, x, y
pkey, lf=lf, g1=g1, g2=g2, shape_noise=shape_noise
)

target_images = get_target_images_single(
draw_params = {**galaxy_params}
draw_params["f"] = 10 ** draw_params.pop("lf")
target_images, _ = get_target_images_single(
nkey,
n_samples=n_gals,
single_galaxy_params=galaxy_params,
psf_hlr=psf_hlr,
single_galaxy_params=draw_params,
background=background,
slen=slen,
pixel_scale=pixel_scale,
)
assert target_images.shape == (n_gals, slen, slen)

true_params = {**galaxy_params}
e1, e2 = true_params.pop("e1"), true_params.pop("e2")
g1, g2 = true_params.pop("g1"), true_params.pop("g2")

e1_prime, e2_prime = scalar_shear_transformation((e1, e2), (g1, g2))
true_params["e1"] = e1_prime
true_params["e2"] = e2_prime
true_params = get_true_params_from_galaxy_params(galaxy_params)

# prepare pipelines
pipe1 = partial(
pipeline_image_interim_samples,
pipeline_image_interim_samples_one_galaxy,
initialization_fnc=init_fnc,
n_samples=k,
max_num_doublings=5,
initial_step_size=1e-3,
n_warmup_steps=500,
is_mass_matrix_diagonal=True,
background=background,
sigma_e_int=sigma_e_int,
n_samples=n_samples_per_gal,
initial_step_size=initial_step_size,
slen=slen,
pixel_scale=pixel_scale,
fft_size=fft_size,
background=background,
)
vpipe1 = vmap(jjit(pipe1), (0, None, 0))

pipe2 = partial(
pipeline_shear_inference,
true_g=jnp.array([g1, g2]),
sigma_e=shape_noise,
sigma_e_int=sigma_e_int,
n_samples=n_samples_shear,
)
vpipe2 = vmap(pipe2, in_axes=(0, 0))

# initialization
gkeys = random.split(gkey, n_gals)
init_positions = vmap(init_fnc, (0, None))(keys, true_params)


galaxy_samples = vpipe1(gkeys, true_params, target_images)



e_post = jnp.stack([galaxy_samples["e1"], galaxy_samples["e2"]], axis=-1)

jnp.save(

g_samples = vpipe2(skey, e_post)


Expand Down
76 changes: 74 additions & 2 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from jax import random, vmap

from bpd.chains import run_inference_nuts
from bpd.initialization import init_with_truth
from bpd.pipelines.image_ellips import (
get_target_galaxy_params_simple,
get_target_images_single,
get_true_params_from_galaxy_params,
pipeline_image_interim_samples_one_galaxy,
)
from bpd.pipelines.shear_inference import pipeline_shear_inference
from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips
from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples
from bpd.prior import ellip_mag_prior, sample_synthetic_sheared_ellips_unclipped


@pytest.mark.parametrize("seed", [1234, 4567])
def test_interim_ellipticity_posterior_convergence(seed):
def test_interim_toy_convergence(seed):
"""Check efficiency and convergence of chains for 100 galaxies."""
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
Expand Down Expand Up @@ -74,7 +81,7 @@ def test_interim_ellipticity_posterior_convergence(seed):


@pytest.mark.parametrize("seed", [1234, 4567])
def test_shear_posterior_convergence(seed):
def test_toy_shear_convergence(seed):
g1, g2 = 0.02, 0.0
sigma_m = 1e-4
sigma_e = 1e-3
Expand Down Expand Up @@ -124,3 +131,68 @@ def test_shear_posterior_convergence(seed):

assert ess > 0.5 * 4000
assert jnp.abs(rhat - 1) < 0.01


@pytest.mark.slow
@pytest.mark.parametrize("seed", [1234, 4567])
def test_low_noise_single_galaxy_interim_samples(seed):
    """Check convergence of interim ellipticity chains on one low-noise galaxy.

    Draws a single galaxy image with very small shape noise, runs 4
    independent NUTS chains initialized at the truth, and checks the
    effective sample size and R-hat of every sampled parameter.
    """
    lf = 6.0  # log10 flux
    hlr = 1.0  # half-light radius
    g1, g2 = 0.02, 0.0  # applied shear
    sigma_e = 1e-3  # intrinsic shape scatter; tiny => 'low noise' regime
    sigma_e_int = 3e-2  # interim ellipticity prior scatter
    n_samples = 500  # posterior samples per chain
    background = 1.0
    slen = 53  # image side length in pixels
    fft_size = 256
    init_fnc = init_with_truth

    rng_key = random.key(seed)
    pkey, nkey, gkey = random.split(rng_key, 3)

    galaxy_params = get_target_galaxy_params_simple(
        pkey, lf=lf, g1=g1, g2=g2, hlr=hlr, shape_noise=sigma_e
    )

    # image-drawing code expects linear flux 'f', not log-flux 'lf'
    draw_params = {**galaxy_params}
    draw_params["f"] = 10 ** draw_params.pop("lf")
    target_image, _ = get_target_images_single(
        nkey,
        n_samples=1,
        single_galaxy_params=draw_params,
        background=background,
        slen=slen,
    )
    assert target_image.shape == (1, slen, slen)

    # sheared ellipticities (shear applied, g1/g2 dropped) are the targets
    true_params = get_true_params_from_galaxy_params(galaxy_params)

    pipe1 = partial(
        pipeline_image_interim_samples_one_galaxy,
        initialization_fnc=init_fnc,
        sigma_e_int=sigma_e_int,
        n_samples=n_samples,
        slen=slen,
        fft_size=fft_size,
        n_warmup_steps=300,
    )
    # vmap over (chain key, init position); same image broadcast to all chains
    vpipe1 = vmap(jjit(pipe1), (0, 0, None))

    # chain initialization
    # one galaxy, test convergence, so 4 random seeds
    gkey1, gkey2 = random.split(gkey, 2)
    keys1 = random.split(gkey1, 4)  # keys for initial positions
    keys2 = random.split(gkey2, 4)  # keys for the samplers themselves

    init_positions = vmap(init_fnc, (0, None))(keys1, true_params)

    samples = vpipe1(keys2, init_positions, target_image)

    # check each component
    for _, v in samples.items():
        assert v.shape == (4, n_samples)  # (n_chains, n_samples_per_chain)
        ess = effective_sample_size(v)
        rhat = potential_scale_reduction(v)

        # NOTE(review): ESS threshold is relative to a single chain's length
        # (n_samples), not the 4*n_samples total draws — confirm intended.
        assert ess > 0.5 * n_samples
        assert jnp.abs(rhat - 1) < 0.01