From 3d48d441d06df7fd4e31eefa581ab49dbb6815d5 Mon Sep 17 00:00:00 2001
From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Date: Tue, 14 Jan 2025 16:37:44 -0600
Subject: [PATCH] Refactoring (jan 2025 edition) (#80)

* benchmarks have been superseded, might revive later
* might put more random stuff here
* all reusable priors and likelihoods go here!
* noise is part of drawing?
* shear transformations now go here
* reusable and most generalized pipelines go here
* generating samples of parameters or images
* pipelines got consolidated
* fix one test
* it's useful to keep the simple pipeline that only does ellipticities
* fix another test
* much less needs to be added see
* fix last test
* organize scripts
* fix experiments
* unchange rerun
---
 bpd/draw.py | 21 ++
 bpd/likelihood.py | 67 ++--
 bpd/noise.py | 21 --
 bpd/pipelines.py | 207 +++++++++++
 bpd/pipelines/image_samples.py | 211 -----------
 bpd/pipelines/shear_inference.py | 69 ----
 bpd/pipelines/toy_ellips.py | 71 ----
 bpd/prior.py | 109 +++---
 bpd/sample.py | 120 +++++++
 bpd/shear.py | 43 +++
 bpd/{measure.py => utils.py} | 0
 experiments/exp1/get_figures.sh | 2 +-
 experiments/exp1/get_posteriors.sh | 2 +-
 .../exp2/run_inference_galaxy_images.py | 4 +-
 experiments/exp30/figs/contours.pdf | Bin 82675 -> 82675 bytes
 experiments/exp30/figs/hists.pdf | Bin 11424 -> 11424 bytes
 experiments/exp30/figs/scatter_shapes.pdf | Bin 133353 -> 133353 bytes
 experiments/exp30/figs/traces.pdf | Bin 61411 -> 61411 bytes
 .../exp30/get_image_interim_samples_fixed.py | 15 +-
 experiments/exp30/get_posteriors.sh | 2 +-
 experiments/exp31/figs/contours.pdf | Bin 82151 -> 82151 bytes
 experiments/exp31/figs/hists.pdf | Bin 11620 -> 11620 bytes
 experiments/exp31/figs/scatter_dxdy.pdf | Bin 135424 -> 135424 bytes
 experiments/exp31/figs/scatter_shapes.pdf | Bin 135190 -> 135190 bytes
 experiments/exp31/figs/traces.pdf | Bin 61484 -> 61484 bytes
 experiments/exp31/get_interim_samples.py | 15 +-
 experiments/exp31/get_posteriors.sh | 2 +-
 experiments/exp32/figs/contours.pdf | Bin 81366 -> 81366 bytes
 experiments/exp32/figs/hists.pdf | Bin 11186 -> 11186 bytes
 experiments/exp32/figs/hists_flux_hlr.pdf | Bin 42335 -> 42335 bytes
 experiments/exp32/figs/scatter_dxdy.pdf | Bin 134816 -> 134816 bytes
 experiments/exp32/figs/scatter_shapes.pdf | Bin 134826 -> 134826 bytes
 experiments/exp32/figs/traces.pdf | Bin 61567 -> 61567 bytes
 experiments/exp32/get_interim_samples.py | 12 +-
 experiments/exp32/get_shear.py | 59 +---
 scripts/benchmarks/benchmark1.py | 316 -----------------
 scripts/benchmarks/benchmark2.py | 329 ------------------
 scripts/benchmarks/benchmark2_7.py | 311 -----------------
 scripts/benchmarks/benchmark2_72.py | 266 --------------
 scripts/benchmarks/benchmark2_8.py | 308 ----------------
 scripts/benchmarks/benchmark_chees1.py | 299 ----------------
 ...im_samples.py => get_shear_from_shapes.py} | 4 +-
 scripts/{ => slurm}/slurm_job.py | 0
 scripts/{ => slurm}/slurm_toy_shear_vectorized.py | 2 +-
 scripts/toy_shear_vectorized.py | 7 +-
 tests/test_convergence.py | 14 +-
 tests/test_shear_inference.py | 7 +-
 tests/test_shear_trans.py | 4 +-
 48 files changed, 527 insertions(+), 2392 deletions(-)
 delete mode 100644 bpd/noise.py
 create mode 100644 bpd/pipelines.py
 delete mode 100644 bpd/pipelines/image_samples.py
 delete mode 100644 bpd/pipelines/shear_inference.py
 delete mode 100644 bpd/pipelines/toy_ellips.py
 create mode 100644 bpd/sample.py
 create mode 100644 bpd/shear.py
 rename bpd/{measure.py => utils.py} (100%)
 delete mode 100755 scripts/benchmarks/benchmark1.py
 delete mode 100755 scripts/benchmarks/benchmark2.py
 delete mode 100755 scripts/benchmarks/benchmark2_7.py
 delete mode 100755 scripts/benchmarks/benchmark2_72.py
 delete mode 100755 scripts/benchmarks/benchmark2_8.py
 delete mode 100755 scripts/benchmarks/benchmark_chees1.py
 rename scripts/{get_shear_from_interim_samples.py => get_shear_from_shapes.py} (92%)
 rename scripts/{ => slurm}/slurm_job.py (100%)
 rename scripts/{ => slurm}/slurm_toy_shear_vectorized.py (97%)

diff --git a/bpd/draw.py b/bpd/draw.py
index 770e745..3e5a48a 100644
--- a/bpd/draw.py
+++ b/bpd/draw.py
@@ -1,5 +1,9 @@
 import galsim
+import jax.numpy as jnp
 import jax_galsim as xgalsim
+from jax import random
+from jax._src.prng import PRNGKeyArray
+from jax.typing import ArrayLike
 from jax_galsim import GSParams
 
@@ -50,3 +54,20 @@ def draw_gaussian_galsim(
     gal_conv = galsim.Convolve([gal, psf])
     image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))
     return image.array
+
+
+def add_noise(
+    rng_key: PRNGKeyArray,
+    x: ArrayLike,
+    bg: float,
+    n: int = 1,
+):
+    """Produce `n` independent Gaussian noise realizations of a given image `x`.
+
+    NOTE: This function assumes the image is background-subtracted and background-dominated.
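+
+    Example (hypothetical values): for a (slen, slen) image `img`,
+    `add_noise(key, img, bg=1e4, n=3)` returns an array of shape (3, slen, slen),
+    i.e. `img` plus three independent noise draws.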
params["lhlr"] + + else: + _draw_params["f"] = fixed_params["f"] + _draw_params["hlr"] = fixed_params["hlr"] + + _draw_params["e1"] = params["e1"] + _draw_params["e2"] = params["e2"] + + model = draw_fnc(**_draw_params) + likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background)) + return jnp.sum(likelihood_pp) diff --git a/bpd/noise.py b/bpd/noise.py deleted file mode 100644 index d8e83f9..0000000 --- a/bpd/noise.py +++ /dev/null @@ -1,21 +0,0 @@ -import jax.numpy as jnp -from jax import random -from jax._src.prng import PRNGKeyArray -from jax.typing import ArrayLike - - -def add_noise( - rng_key: PRNGKeyArray, - x: ArrayLike, - bg: float, - n: int = 1, -): - """Produce `n` independent Gaussian noise realizations of a given image `x`. - - NOTE: This function assumes image is background-subtracted and dominated. - """ - assert isinstance(bg, float) or bg.shape == () - x = x.reshape(1, *x.shape) - x = x.repeat(n, axis=0) - noise = random.normal(rng_key, shape=x.shape) * jnp.sqrt(bg) - return x + noise diff --git a/bpd/pipelines.py b/bpd/pipelines.py new file mode 100644 index 0000000..ae1e4a4 --- /dev/null +++ b/bpd/pipelines.py @@ -0,0 +1,207 @@ +from functools import partial +from typing import Callable + +import jax.numpy as jnp +import jax.scipy as jsp +from jax import Array, jit, random, vmap +from jax._src.prng import PRNGKeyArray +from jax.scipy import stats + +from bpd.chains import run_inference_nuts +from bpd.likelihood import shear_loglikelihood +from bpd.prior import ellip_prior_e1e2, true_ellip_logprior +from bpd.sample import sample_noisy_ellipticities_unclipped + + +def logtarget_shear( + g: Array, *, data: Array | dict[str, Array], loglikelihood: Callable, sigma_g: float +): + loglike = loglikelihood(g, post_params=data) + logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum() + return logprior + loglike + + +def pipeline_shear_inference( + rng_key: PRNGKeyArray, + post_params: Array, + init_g: Array, + *, + logprior: Callable, + interim_logprior: Callable, + n_samples: int, + initial_step_size: float, + sigma_g: float = 0.01, + n_warmup_steps: int = 500, + max_num_doublings: int = 2, +): + _loglikelihood = partial( + shear_loglikelihood, logprior=logprior, interim_logprior=interim_logprior + ) + _loglikelihood_jitted = jit(_loglikelihood) + + _logtarget = partial( + logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g + ) + + _do_inference = partial( + run_inference_nuts, + data=post_params, + logtarget=_logtarget, + n_samples=n_samples, + n_warmup_steps=n_warmup_steps, + max_num_doublings=max_num_doublings, + initial_step_size=initial_step_size, + ) + + g_samples = _do_inference(rng_key, init_g) + return g_samples + + +def pipeline_shear_inference_simple( + rng_key: PRNGKeyArray, + e_post: Array, + init_g: Array, + *, + sigma_e: float, + sigma_e_int: float, + n_samples: int, + initial_step_size: float, + sigma_g: float = 0.01, + n_warmup_steps: int = 500, + max_num_doublings: int = 2, +): + _logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e) + _interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int)) + + _loglikelihood = partial( + shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior + ) + _loglikelihood_jitted = jit(_loglikelihood) + + _logtarget = partial( + logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g + ) + + _do_inference = partial( + run_inference_nuts, + data=e_post, + logtarget=_logtarget, + n_samples=n_samples, + 
+
+
+def pipeline_shear_inference(
+    rng_key: PRNGKeyArray,
+    post_params: Array,
+    init_g: Array,
+    *,
+    logprior: Callable,
+    interim_logprior: Callable,
+    n_samples: int,
+    initial_step_size: float,
+    sigma_g: float = 0.01,
+    n_warmup_steps: int = 500,
+    max_num_doublings: int = 2,
+):
+    _loglikelihood = partial(
+        shear_loglikelihood, logprior=logprior, interim_logprior=interim_logprior
+    )
+    _loglikelihood_jitted = jit(_loglikelihood)
+
+    _logtarget = partial(
+        logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g
+    )
+
+    _do_inference = partial(
+        run_inference_nuts,
+        data=post_params,
+        logtarget=_logtarget,
+        n_samples=n_samples,
+        n_warmup_steps=n_warmup_steps,
+        max_num_doublings=max_num_doublings,
+        initial_step_size=initial_step_size,
+    )
+
+    g_samples = _do_inference(rng_key, init_g)
+    return g_samples
+
+
+def pipeline_shear_inference_simple(
+    rng_key: PRNGKeyArray,
+    e_post: Array,
+    init_g: Array,
+    *,
+    sigma_e: float,
+    sigma_e_int: float,
+    n_samples: int,
+    initial_step_size: float,
+    sigma_g: float = 0.01,
+    n_warmup_steps: int = 500,
+    max_num_doublings: int = 2,
+):
+    _logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e)
+    _interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int))
+
+    _loglikelihood = partial(
+        shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior
+    )
+    _loglikelihood_jitted = jit(_loglikelihood)
+
+    _logtarget = partial(
+        logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g
+    )
+
+    _do_inference = partial(
+        run_inference_nuts,
+        data=e_post,
+        logtarget=_logtarget,
+        n_samples=n_samples,
+        n_warmup_steps=n_warmup_steps,
+        max_num_doublings=max_num_doublings,
+        initial_step_size=initial_step_size,
+    )
+
+    return _do_inference(rng_key, init_g)
+
+
+def logtarget_images(
+    params: dict[str, Array],
+    data: Array,
+    *,
+    fixed_params: dict[str, Array],
+    logprior_fnc: Callable,
+    loglikelihood_fnc: Callable,
+):
+    return logprior_fnc(params) + loglikelihood_fnc(params, data, fixed_params)
+
+
+def pipeline_interim_samples_one_galaxy(
+    rng_key: PRNGKeyArray,
+    true_params: dict[str, float],
+    target_image: Array,
+    fixed_params: dict[str, float],
+    *,
+    initialization_fnc: Callable,
+    logprior: Callable,
+    loglikelihood: Callable,
+    n_samples: int = 100,
+    max_num_doublings: int = 5,
+    initial_step_size: float = 1e-3,
+    n_warmup_steps: int = 500,
+    is_mass_matrix_diagonal: bool = True,
+):
+    # Flux and HLR are fixed to truth and not inferred in this function.
+    k1, k2 = random.split(rng_key)
+
+    init_position = initialization_fnc(k1, true_params=true_params, data=target_image)
+
+    _logtarget = partial(
+        logtarget_images,
+        logprior_fnc=logprior,
+        loglikelihood_fnc=loglikelihood,
+        fixed_params=fixed_params,
+    )
+
+    _inference_fnc = partial(
+        run_inference_nuts,
+        logtarget=_logtarget,
+        is_mass_matrix_diagonal=is_mass_matrix_diagonal,
+        n_warmup_steps=n_warmup_steps,
+        max_num_doublings=max_num_doublings,
+        initial_step_size=initial_step_size,
+        n_samples=n_samples,
+    )
+    _run_inference = jit(_inference_fnc)
+
+    interim_samples = _run_inference(k2, init_position, target_image)
+    return interim_samples
+
+
+def logtarget_toy_ellips(
+    e_sheared: Array,
+    *,
+    data: Array,  # renamed from `e_obs` for compatibility with `do_inference_nuts`
+    sigma_m: float,
+    sigma_e_int: float,
+):
+    e_obs = data
+    assert e_sheared.shape == (2,) and e_obs.shape == (2,)
+
+    prior = jnp.log(ellip_prior_e1e2(e_sheared, sigma=sigma_e_int))
+    likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
+    return prior + likelihood
+
+
+def pipeline_toy_ellips(
+    key: PRNGKeyArray,
+    *,
+    g1: float,
+    g2: float,
+    sigma_e: float,
+    sigma_e_int: float,
+    sigma_m: float,
+    n_gals: int,
+    n_samples_per_gal: int,
+    n_warmup_steps: int = 500,
+    max_num_doublings: int = 2,
+):
+    k1, k2 = random.split(key)
+
+    true_g = jnp.array([g1, g2])
+
+    e_obs, e_sheared, _ = sample_noisy_ellipticities_unclipped(
+        k1, g=true_g, sigma_m=sigma_m, sigma_e=sigma_e, n=n_gals
+    )
+
+    _logtarget = partial(logtarget_toy_ellips, sigma_m=sigma_m, sigma_e_int=sigma_e_int)
+
+    keys2 = random.split(k2, n_gals)
+    _do_inference_jitted = jit(
+        partial(
+            run_inference_nuts,
+            logtarget=_logtarget,
+            n_samples=n_samples_per_gal,
+            initial_step_size=max(sigma_e, sigma_m),
+            max_num_doublings=max_num_doublings,
+            n_warmup_steps=n_warmup_steps,
+        )
+    )
+    _do_inference = vmap(_do_inference_jitted, in_axes=(0, 0, 0))
+
+    # compile
+    _ = _do_inference(keys2[:2], e_sheared[:2], e_obs[:2])
+
+    e_post = _do_inference(keys2, e_sheared, e_obs)
+
+    return e_post, e_obs, e_sheared

diff --git a/bpd/pipelines/image_samples.py b/bpd/pipelines/image_samples.py
deleted file mode 100644
index c957028..0000000
--- a/bpd/pipelines/image_samples.py
+++ /dev/null
@@ -1,211 +0,0 @@
-from functools import partial
-from typing import Callable
-
-import jax.numpy as jnp
-from jax import Array, jit, random
-from jax._src.prng import PRNGKeyArray
-from jax.scipy import stats
-
-from bpd.chains import run_inference_nuts
-from bpd.draw import draw_gaussian_galsim
-from bpd.noise import add_noise
-from bpd.prior import (
-    ellip_prior_e1e2,
-    
sample_ellip_prior, - scalar_shear_transformation, -) - - -def sample_target_galaxy_params_simple( - rng_key: PRNGKeyArray, - *, - shape_noise: float, - g1: float = 0.02, - g2: float = 0.0, -): - """Fix parameters except position and ellipticity, which come from a prior. - - * The position is drawn uniformly within a pixel (dither). - * The ellipticity is drawn from Gary's prior given the shape noise. - - """ - dkey, ekey = random.split(rng_key, 2) - - x, y = random.uniform(dkey, shape=(2,), minval=-0.5, maxval=0.5) - e = sample_ellip_prior(ekey, sigma=shape_noise, n=1) - return { - "e1": e[0, 0], - "e2": e[0, 1], - "x": x, - "y": y, - "g1": g1, - "g2": g2, - } - - -def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]): - true_params = {**galaxy_params} - e1, e2 = true_params.pop("e1"), true_params.pop("e2") - g1, g2 = true_params.pop("g1"), true_params.pop("g2") - - e1_prime, e2_prime = scalar_shear_transformation( - jnp.array([e1, e2]), jnp.array([g1, g2]) - ) - true_params["e1"] = e1_prime - true_params["e2"] = e2_prime - - return true_params # don't add back g1,g2 as we are not inferring those in interim posterior - - -# interim prior -def logprior( - params: dict[str, Array], - *, - sigma_e: float, - sigma_x: float = 0.5, # pixels - flux_bds: tuple = (-1.0, 9.0), - hlr_bds: tuple = (-2.0, 1.0), - free_flux_hlr: bool = True, - free_dxdy: bool = True, -) -> Array: - prior = jnp.array(0.0) - - if free_flux_hlr: - f1, f2 = flux_bds - prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1) - - h1, h2 = hlr_bds - prior += stats.uniform.logpdf(params["lhlr"], h1, h2 - h1) - - if free_dxdy: - prior += stats.norm.logpdf(params["dx"], loc=0.0, scale=sigma_x) - prior += stats.norm.logpdf(params["dy"], loc=0.0, scale=sigma_x) - - e1e2 = jnp.stack((params["e1"], params["e2"]), axis=-1) - prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e)) - - return prior - - -def loglikelihood( - params: dict[str, Array], - data: Array, - fixed_params: dict[str, Array], - *, - draw_fnc: Callable, - background: float, - free_flux_hlr: bool = True, - free_dxdy: bool = True, -): - _draw_params = {} - - if free_dxdy: - _draw_params["x"] = params["dx"] + fixed_params["x"] - _draw_params["y"] = params["dy"] + fixed_params["y"] - - else: - _draw_params["x"] = fixed_params["x"] - _draw_params["y"] = fixed_params["y"] - - if free_flux_hlr: - _draw_params["f"] = 10 ** params["lf"] - _draw_params["hlr"] = 10 ** params["lhlr"] - - else: - _draw_params["f"] = fixed_params["f"] - _draw_params["hlr"] = fixed_params["hlr"] - - _draw_params["e1"] = params["e1"] - _draw_params["e2"] = params["e2"] - - model = draw_fnc(**_draw_params) - likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background)) - return jnp.sum(likelihood_pp) - - -def logtarget( - params: dict[str, Array], - data: Array, - *, - fixed_params: dict[str, Array], - logprior_fnc: Callable, - loglikelihood_fnc: Callable, -): - return logprior_fnc(params) + loglikelihood_fnc(params, data, fixed_params) - - -def get_target_images_single( - rng_key: PRNGKeyArray, - *, - single_galaxy_params: dict[str, float], - background: float, - slen: int, - n_samples: int = 1, # single noise realization -): - """Multiple noise realizations of single galaxy (GalSim).""" - noiseless = draw_gaussian_galsim(**single_galaxy_params, slen=slen) - return add_noise(rng_key, noiseless, bg=background, n=n_samples) - - -def get_target_images( - rng_key: PRNGKeyArray, - galaxy_params: dict[str, Array], - *, - background: float, - slen: int, -): - """Single 
noise realization of multiple galaxies (GalSim).""" - n_gals = galaxy_params["f"].shape[0] - nkeys = random.split(rng_key, n_gals) - - target_images = [] - for ii in range(n_gals): - _params = {k: v[ii].item() for k, v in galaxy_params.items()} - noiseless = draw_gaussian_galsim(**_params, slen=slen) - target_image = add_noise(nkeys[ii], noiseless, bg=background, n=1) - assert target_image.shape == (1, slen, slen) - target_images.append(target_image) - - return jnp.concatenate(target_images, axis=0) - - -def pipeline_interim_samples_one_galaxy( - rng_key: PRNGKeyArray, - true_params: dict[str, float], - target_image: Array, - fixed_params: dict[str, float], - *, - initialization_fnc: Callable, - logprior: Callable, - loglikelihood: Callable, - n_samples: int = 100, - max_num_doublings: int = 5, - initial_step_size: float = 1e-3, - n_warmup_steps: int = 500, - is_mass_matrix_diagonal: bool = True, -): - # Flux and HLR are fixed to truth and not inferred in this function. - k1, k2 = random.split(rng_key) - - init_position = initialization_fnc(k1, true_params=true_params, data=target_image) - - _logtarget = partial( - logtarget, - logprior_fnc=logprior, - loglikelihood_fnc=loglikelihood, - fixed_params=fixed_params, - ) - - _inference_fnc = partial( - run_inference_nuts, - logtarget=_logtarget, - is_mass_matrix_diagonal=is_mass_matrix_diagonal, - n_warmup_steps=n_warmup_steps, - max_num_doublings=max_num_doublings, - initial_step_size=initial_step_size, - n_samples=n_samples, - ) - _run_inference = jit(_inference_fnc) - - interim_samples = _run_inference(k2, init_position, target_image) - return interim_samples diff --git a/bpd/pipelines/shear_inference.py b/bpd/pipelines/shear_inference.py deleted file mode 100644 index ece0e5d..0000000 --- a/bpd/pipelines/shear_inference.py +++ /dev/null @@ -1,69 +0,0 @@ -from functools import partial -from typing import Callable - -import jax.numpy as jnp -from jax import Array, jit -from jax._src.prng import PRNGKeyArray -from jax.scipy import stats - -from bpd.chains import run_inference_nuts -from bpd.likelihood import shear_loglikelihood, true_ellip_logprior -from bpd.prior import ellip_prior_e1e2 - - -def logtarget_density( - g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float = 0.01 -): - loglike = loglikelihood(g, post_params=data) - logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum() - return logprior + loglike - - -def _logprior(post_params: dict[str, Array], g: Array, *, sigma_e: float): - e_post = post_params["e1e2"] - return true_ellip_logprior(e_post, g, sigma_e=sigma_e) - - -def _interim_logprior(post_params: dict[str, Array], sigma_e_int: float): - e_post = post_params["e1e2"] - return jnp.log(ellip_prior_e1e2(e_post, sigma=sigma_e_int)) - - -def pipeline_shear_inference_ellipticities( - rng_key: PRNGKeyArray, - e_post: Array, - init_g: Array, - *, - sigma_e: float, - sigma_e_int: float, - n_samples: int, - initial_step_size: float, - sigma_g: float = 0.01, - n_warmup_steps: int = 500, - max_num_doublings: int = 2, -): - # NOTE: jit must be applied without `e_post` in partial! 
- _loglikelihood = jit( - partial( - shear_loglikelihood, - logprior=partial(_logprior, sigma_e=sigma_e), - interim_logprior=partial(_interim_logprior, sigma_e_int=sigma_e_int), - ) - ) - _logtarget = partial( - logtarget_density, loglikelihood=_loglikelihood, sigma_g=sigma_g - ) - - _do_inference = partial( - run_inference_nuts, - data={"e1e2": e_post}, - logtarget=_logtarget, - n_samples=n_samples, - n_warmup_steps=n_warmup_steps, - max_num_doublings=max_num_doublings, - initial_step_size=initial_step_size, - ) - - g_samples = _do_inference(rng_key, init_g) - - return g_samples diff --git a/bpd/pipelines/toy_ellips.py b/bpd/pipelines/toy_ellips.py deleted file mode 100644 index 22bad6b..0000000 --- a/bpd/pipelines/toy_ellips.py +++ /dev/null @@ -1,71 +0,0 @@ -from functools import partial - -import jax.numpy as jnp -import jax.scipy as jsp -from jax import Array, jit, random, vmap -from jax._src.prng import PRNGKeyArray - -from bpd.chains import run_inference_nuts -from bpd.prior import ( - ellip_prior_e1e2, - sample_noisy_ellipticities_unclipped, -) - - -def logtarget( - e_sheared: Array, - *, - data: Array, # renamed from `e_obs` for comptability with `do_inference_nuts` - sigma_m: float, - sigma_e_int: float, -): - e_obs = data - assert e_sheared.shape == (2,) and e_obs.shape == (2,) - - prior = jnp.log(ellip_prior_e1e2(e_sheared, sigma=sigma_e_int)) - likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m)) - return prior + likelihood - - -def pipeline_toy_ellips_samples( - key: PRNGKeyArray, - *, - g1: float, - g2: float, - sigma_e: float, - sigma_e_int: float, - sigma_m: float, - n_gals: int, - n_samples_per_gal: int, - n_warmup_steps: int = 500, - max_num_doublings: int = 2, -): - k1, k2 = random.split(key) - - true_g = jnp.array([g1, g2]) - - e_obs, e_sheared, _ = sample_noisy_ellipticities_unclipped( - k1, g=true_g, sigma_m=sigma_m, sigma_e=sigma_e, n=n_gals - ) - - _logtarget = partial(logtarget, sigma_m=sigma_m, sigma_e_int=sigma_e_int) - - keys2 = random.split(k2, n_gals) - _do_inference_jitted = jit( - partial( - run_inference_nuts, - logtarget=_logtarget, - n_samples=n_samples_per_gal, - initial_step_size=max(sigma_e, sigma_m), - max_num_doublings=max_num_doublings, - n_warmup_steps=n_warmup_steps, - ) - ) - _do_inference = vmap(_do_inference_jitted, in_axes=(0, 0, 0)) - - # compile - _ = _do_inference(keys2[:2], e_sheared[:2], e_obs[:2]) - - e_post = _do_inference(keys2, e_sheared, e_obs) - - return e_post, e_obs, e_sheared diff --git a/bpd/prior.py b/bpd/prior.py index 4b438a6..c7395f1 100644 --- a/bpd/prior.py +++ b/bpd/prior.py @@ -1,9 +1,15 @@ import jax.numpy as jnp -from jax import Array, random, vmap -from jax._src.prng import PRNGKeyArray +from jax import Array, grad, vmap from jax.numpy.linalg import norm +from jax.scipy import stats from jax.typing import ArrayLike +from bpd.shear import ( + inv_shear_func1, + inv_shear_func2, + inv_shear_transformation, +) + def ellip_mag_prior(e_mag: ArrayLike, sigma: float) -> ArrayLike: """Prior for the magnitude of the ellipticity with domain (0, 1). 
@@ -37,79 +43,50 @@ def ellip_prior_e1e2(e1e2: Array, sigma: float) -> ArrayLike: return (1 - e_mag**2) ** 2 * jnp.exp(-(e_mag**2) / (2 * sigma**2)) / _norm -def sample_mag_ellip_prior( - rng_key: PRNGKeyArray, sigma: float, n: int = 1, n_bins: int = 1_000_000 -): - """Sample n points from GB's ellipticity magnitude prior.""" - e_mag_array = jnp.linspace(0, 1, n_bins) - p_array = ellip_mag_prior(e_mag_array, sigma=sigma) - p_array /= p_array.sum() - return random.choice(rng_key, e_mag_array, shape=(n,), p=p_array) - - -def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1): - """Sample n ellipticities isotropic components with Gary's prior for magnitude.""" - key1, key2 = random.split(rng_key, 2) - e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n) - e_phi = random.uniform(key2, shape=(n,), minval=0, maxval=jnp.pi) - e1 = e_mag * jnp.cos(2 * e_phi) - e2 = e_mag * jnp.sin(2 * e_phi) - return jnp.stack((e1, e2), axis=1) - - -def scalar_shear_transformation(e: Array, g: Array): - """Transform elliptiticies by a fixed shear (scalar version). - - The transformation we used is equation 3.4b in Seitz & Schneider (1997). +_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None)) +_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None)) +_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None)) - NOTE: This function is meant to be vmapped later. - """ - assert e.shape == (2,) and g.shape == (2,) - - e1, e2 = e - g1, g2 = g - e_comp = e1 + e2 * 1j - g_comp = g1 + g2 * 1j +def true_ellip_logprior(e_post: Array, g: Array, *, sigma_e: float): + """Implementation of GB's true prior on interim posterior samples of ellipticities.""" - e_prime = (e_comp + g_comp) / (1 + g_comp.conjugate() * e_comp) - return jnp.array([e_prime.real, e_prime.imag]) + # jacobian of inverse shear transformation + grad1 = _grad_fnc1(e_post, g) + grad2 = _grad_fnc2(e_post, g) + absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0]) + # true prior on unsheared ellipticity + e_post_unsheared = _inv_shear_trans(e_post, g) + prior_val = ellip_prior_e1e2(e_post_unsheared, sigma=sigma_e) -def scalar_inv_shear_transformation(e: Array, g: Array): - """Same as above but the inverse.""" - assert e.shape == (2,) and g.shape == (2,) - e1, e2 = e - g1, g2 = g + return jnp.log(prior_val) + jnp.log(absjacdet) - e_comp = e1 + e2 * 1j - g_comp = g1 + g2 * 1j - e_prime = (e_comp - g_comp) / (1 - g_comp.conjugate() * e_comp) - return jnp.array([e_prime.real, e_prime.imag]) +def interim_gprops_logprior( + params: dict[str, Array], + *, + sigma_e: float, + sigma_x: float = 0.5, # pixels + flux_bds: tuple = (-1.0, 9.0), + hlr_bds: tuple = (-2.0, 1.0), + free_flux_hlr: bool = True, + free_dxdy: bool = True, +) -> Array: + prior = jnp.array(0.0) + if free_flux_hlr: + f1, f2 = flux_bds + prior += stats.uniform.logpdf(params["lf"], f1, f2 - f1) -# batched -shear_transformation = vmap(scalar_shear_transformation, in_axes=(0, None)) -inv_shear_transformation = vmap(scalar_inv_shear_transformation, in_axes=(0, None)) + h1, h2 = hlr_bds + prior += stats.uniform.logpdf(params["lhlr"], h1, h2 - h1) -# useful for jacobian later -inv_shear_func1 = lambda e, g: scalar_inv_shear_transformation(e, g)[0] -inv_shear_func2 = lambda e, g: scalar_inv_shear_transformation(e, g)[1] + if free_dxdy: + prior += stats.norm.logpdf(params["dx"], loc=0.0, scale=sigma_x) + prior += stats.norm.logpdf(params["dy"], loc=0.0, scale=sigma_x) + e1e2 = 
jnp.stack((params["e1"], params["e2"]), axis=-1)
+    prior += jnp.log(ellip_prior_e1e2(e1e2, sigma=sigma_e))
-def sample_noisy_ellipticities_unclipped(
-    rng_key: PRNGKeyArray,
-    *,
-    g: Array,
-    sigma_m: float,
-    sigma_e: float,
-    n: int = 1,
-):
-    """We sample noisy sheared ellipticities from N(e_int + g, sigma_m^2)"""
-    key1, key2 = random.split(rng_key, 2)
-
-    e_int = sample_ellip_prior(key1, sigma=sigma_e, n=n)
-    e_sheared = shear_transformation(e_int, g)
-    e_obs = random.normal(key2, shape=(n, 2)) * sigma_m + e_sheared.reshape(n, 2)
-    return e_obs, e_sheared, e_int
+    return prior

diff --git a/bpd/sample.py b/bpd/sample.py
new file mode 100644
index 0000000..4eef590
--- /dev/null
+++ b/bpd/sample.py
@@ -0,0 +1,120 @@
+import jax.numpy as jnp
+from jax import Array, random
+from jax._src.prng import PRNGKeyArray
+
+from bpd.draw import add_noise, draw_gaussian_galsim
+from bpd.prior import ellip_mag_prior
+from bpd.shear import scalar_shear_transformation, shear_transformation
+
+
+def sample_mag_ellip_prior(
+    rng_key: PRNGKeyArray, sigma: float, n: int = 1, n_bins: int = 1_000_000
+):
+    """Sample n points from GB's ellipticity magnitude prior."""
+    e_mag_array = jnp.linspace(0, 1, n_bins)
+    p_array = ellip_mag_prior(e_mag_array, sigma=sigma)
+    p_array /= p_array.sum()
+    return random.choice(rng_key, e_mag_array, shape=(n,), p=p_array)
+
+
+def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
+    """Sample n ellipticities with isotropic orientation and Gary's prior on the magnitude."""
+    key1, key2 = random.split(rng_key, 2)
+    e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n)
+    e_phi = random.uniform(key2, shape=(n,), minval=0, maxval=jnp.pi)
+    e1 = e_mag * jnp.cos(2 * e_phi)
+    e2 = e_mag * jnp.sin(2 * e_phi)
+    return jnp.stack((e1, e2), axis=1)
+
+
+def sample_noisy_ellipticities_unclipped(
+    rng_key: PRNGKeyArray,
+    *,
+    g: Array,
+    sigma_m: float,
+    sigma_e: float,
+    n: int = 1,
+):
+    """Sample noisy sheared ellipticities: e_obs ~ N(shear(e_int, g), sigma_m^2)."""
+    key1, key2 = random.split(rng_key, 2)
+
+    e_int = sample_ellip_prior(key1, sigma=sigma_e, n=n)
+    e_sheared = shear_transformation(e_int, g)
+    e_obs = random.normal(key2, shape=(n, 2)) * sigma_m + e_sheared.reshape(n, 2)
+    return e_obs, e_sheared, e_int
+
+
+def sample_target_galaxy_params_simple(
+    rng_key: PRNGKeyArray,
+    *,
+    shape_noise: float,
+    g1: float = 0.02,
+    g2: float = 0.0,
+):
+    """Fix parameters except position and ellipticity, which come from a prior.
+
+    * The position is drawn uniformly within a pixel (dither).
+    * The ellipticity is drawn from Gary's prior given the shape noise.
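+    * The shear (g1, g2) is not sampled; it stays fixed at the passed-in values
+      (0.02 and 0.0 by default).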
+ + """ + dkey, ekey = random.split(rng_key, 2) + + x, y = random.uniform(dkey, shape=(2,), minval=-0.5, maxval=0.5) + e = sample_ellip_prior(ekey, sigma=shape_noise, n=1) + return { + "e1": e[0, 0], + "e2": e[0, 1], + "x": x, + "y": y, + "g1": g1, + "g2": g2, + } + + +def get_true_params_from_galaxy_params(galaxy_params: dict[str, Array]): + true_params = {**galaxy_params} + e1, e2 = true_params.pop("e1"), true_params.pop("e2") + g1, g2 = true_params.pop("g1"), true_params.pop("g2") + + e1_prime, e2_prime = scalar_shear_transformation( + jnp.array([e1, e2]), jnp.array([g1, g2]) + ) + true_params["e1"] = e1_prime + true_params["e2"] = e2_prime + + return true_params # don't add back g1,g2 as we are not inferring those in interim posterior + + +def get_target_images_single( + rng_key: PRNGKeyArray, + *, + single_galaxy_params: dict[str, float], + background: float, + slen: int, + n_samples: int = 1, # single noise realization +): + """Multiple noise realizations of single galaxy (GalSim).""" + noiseless = draw_gaussian_galsim(**single_galaxy_params, slen=slen) + return add_noise(rng_key, noiseless, bg=background, n=n_samples) + + +def get_target_images( + rng_key: PRNGKeyArray, + galaxy_params: dict[str, Array], + *, + background: float, + slen: int, +): + """Single noise realization of multiple galaxies (GalSim).""" + n_gals = galaxy_params["f"].shape[0] + nkeys = random.split(rng_key, n_gals) + + target_images = [] + for ii in range(n_gals): + _params = {k: v[ii].item() for k, v in galaxy_params.items()} + noiseless = draw_gaussian_galsim(**_params, slen=slen) + target_image = add_noise(nkeys[ii], noiseless, bg=background, n=1) + assert target_image.shape == (1, slen, slen) + target_images.append(target_image) + + return jnp.concatenate(target_images, axis=0) diff --git a/bpd/shear.py b/bpd/shear.py new file mode 100644 index 0000000..4a4be13 --- /dev/null +++ b/bpd/shear.py @@ -0,0 +1,43 @@ +import jax.numpy as jnp +from jax import Array, vmap + + +def scalar_shear_transformation(e: Array, g: Array): + """Transform elliptiticies by a fixed shear (scalar version). + + The transformation we used is equation 3.4b in Seitz & Schneider (1997). + + NOTE: This function is meant to be vmapped later. 
+ """ + assert e.shape == (2,) and g.shape == (2,) + + e1, e2 = e + g1, g2 = g + + e_comp = e1 + e2 * 1j + g_comp = g1 + g2 * 1j + + e_prime = (e_comp + g_comp) / (1 + g_comp.conjugate() * e_comp) + return jnp.array([e_prime.real, e_prime.imag]) + + +def scalar_inv_shear_transformation(e: Array, g: Array): + """Same as above but the inverse.""" + assert e.shape == (2,) and g.shape == (2,) + e1, e2 = e + g1, g2 = g + + e_comp = e1 + e2 * 1j + g_comp = g1 + g2 * 1j + + e_prime = (e_comp - g_comp) / (1 - g_comp.conjugate() * e_comp) + return jnp.array([e_prime.real, e_prime.imag]) + + +# batched +shear_transformation = vmap(scalar_shear_transformation, in_axes=(0, None)) +inv_shear_transformation = vmap(scalar_inv_shear_transformation, in_axes=(0, None)) + +# useful for jacobian later +inv_shear_func1 = lambda e, g: scalar_inv_shear_transformation(e, g)[0] +inv_shear_func2 = lambda e, g: scalar_inv_shear_transformation(e, g)[1] diff --git a/bpd/measure.py b/bpd/utils.py similarity index 100% rename from bpd/measure.py rename to bpd/utils.py diff --git a/experiments/exp1/get_figures.sh b/experiments/exp1/get_figures.sh index e07ce5d..02e761f 100755 --- a/experiments/exp1/get_figures.sh +++ b/experiments/exp1/get_figures.sh @@ -1,2 +1,2 @@ #!/bin/bash -./make_figures.py 43 +./make_figures.py 44 diff --git a/experiments/exp1/get_posteriors.sh b/experiments/exp1/get_posteriors.sh index 08b087b..66a6dee 100755 --- a/experiments/exp1/get_posteriors.sh +++ b/experiments/exp1/get_posteriors.sh @@ -1,2 +1,2 @@ #!/bin/bash -../../scripts/slurm_toy_shear_vectorized.py 44 toy_shear_44 +../../scripts/slurm/slurm_toy_shear_vectorized.py 44 toy_shear_44 diff --git a/experiments/exp2/run_inference_galaxy_images.py b/experiments/exp2/run_inference_galaxy_images.py index 270f082..6e22709 100755 --- a/experiments/exp2/run_inference_galaxy_images.py +++ b/experiments/exp2/run_inference_galaxy_images.py @@ -15,12 +15,12 @@ from bpd.chains import run_sampling_nuts, run_warmup_nuts from bpd.draw import draw_gaussian from bpd.initialization import init_with_prior -from bpd.pipelines.image_samples import ( +from bpd.prior import ellip_prior_e1e2 +from bpd.sample import ( get_target_images, get_true_params_from_galaxy_params, sample_target_galaxy_params_simple, ) -from bpd.prior import ellip_prior_e1e2 def logprior( diff --git a/experiments/exp30/figs/contours.pdf b/experiments/exp30/figs/contours.pdf index ecec485c21e3d4f8ba34ae0f51fa55ee49f78b21..1a3d9ebd3aaa59b4fa4487e21fa1de7533bb1a8e 100644 GIT binary patch delta 23 ecmey|%KEvLwP6e6b~{!R0}B&_?R)GP%~=3)k_Z<7 delta 23 ecmey|%KEvLwP6e6b~{!BQ&U5u?R)GP%~=3)a|jdw diff --git a/experiments/exp30/figs/hists.pdf b/experiments/exp30/figs/hists.pdf index 5696cebdb93b23e97508c2f819552d1816002d50..6a6d867c081e6cb3cbc35473f92b423277dbc079 100644 GIT binary patch delta 18 acmZ1wxgc`G2Mtye0}B&_&0jTEGXnrgfd@VS delta 18 acmZ1wxgc`G2MtyOQ&U5u&0jTEGXnrgWd}O| diff --git a/experiments/exp30/figs/scatter_shapes.pdf b/experiments/exp30/figs/scatter_shapes.pdf index 90a73176e12a5e96b6e54760481b2180b9d6f207..51ec341827452f250268cbdf79c129ee9b6f5b48 100644 GIT binary patch delta 23 fcmaDkk>lk=j)pCaZ~w5G7+4rvZ2$C!@jEvFg`5i4 delta 23 fcmaDkk>lk=j)pCaZ~w3wn3@_IZvXU$@jEvFg+>a@ diff --git a/experiments/exp30/figs/traces.pdf b/experiments/exp30/figs/traces.pdf index f5b7ab4b8b51694142e4cb8ec2e2f04e20d562ce..354145e3bb9debe7d1c1aab18c3a35ff5b053893 100644 GIT binary patch delta 20 ccmaESpZW29<_+<0SxpQqObj-sysc*f0BU0i5dZ)H delta 20 ccmaESpZW29<_+<0Sq)504UIOZysc*f0BT1G4gdfE diff --git 
a/experiments/exp30/get_image_interim_samples_fixed.py b/experiments/exp30/get_image_interim_samples_fixed.py index 98190c1..7cf7ec7 100755 --- a/experiments/exp30/get_image_interim_samples_fixed.py +++ b/experiments/exp30/get_image_interim_samples_fixed.py @@ -9,12 +9,12 @@ from bpd.draw import draw_gaussian from bpd.initialization import init_with_truth from bpd.io import save_dataset -from bpd.pipelines.image_samples import ( +from bpd.likelihood import gaussian_image_loglikelihood +from bpd.pipelines import pipeline_interim_samples_one_galaxy +from bpd.prior import interim_gprops_logprior +from bpd.sample import ( get_target_images, get_true_params_from_galaxy_params, - loglikelihood, - logprior, - pipeline_interim_samples_one_galaxy, sample_target_galaxy_params_simple, ) @@ -71,12 +71,15 @@ def main( # setup prior and likelihood _logprior = partial( - logprior, sigma_e=sigma_e_int, free_flux_hlr=False, free_dxdy=False + interim_gprops_logprior, + sigma_e=sigma_e_int, + free_flux_hlr=False, + free_dxdy=False, ) _draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size) _loglikelihood = partial( - loglikelihood, + gaussian_image_loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux_hlr=False, diff --git a/experiments/exp30/get_posteriors.sh b/experiments/exp30/get_posteriors.sh index 00a3e63..4729eab 100755 --- a/experiments/exp30/get_posteriors.sh +++ b/experiments/exp30/get_posteriors.sh @@ -4,4 +4,4 @@ export JAX_ENABLE_X64="True" SEED="43" ./get_image_interim_samples_fixed.py $SEED -../../scripts/get_shear_from_interim_samples.py $SEED exp30_$SEED "e_post_${SEED}.npz" --overwrite +../../scripts/get_shear_from_shapes.py $SEED exp30_$SEED "e_post_${SEED}.npz" --overwrite diff --git a/experiments/exp31/figs/contours.pdf b/experiments/exp31/figs/contours.pdf index fcde8d708f039e4cef713f2785dd786ed4eaf31e..5f9419bf0d0916d862cf845afd73f27ca1c0dec0 100644 GIT binary patch delta 23 fcmaFf$ojmIwP6e6a~oC@Lkko0?Qd)te=`FBbs7mC delta 23 fcmaFf$ojmIwP6e6a~oDe17kze?Qd)te=`FBbestg diff --git a/experiments/exp31/figs/hists.pdf b/experiments/exp31/figs/hists.pdf index 93bb5a442a345df44c051dcd71276f744d1525aa..dfeed00e7d942c4ffd728c4011f08969b9e67689 100644 GIT binary patch delta 18 acmaD7^(1P;0!>yELkko0%}X`+GXnrk!v{tH delta 18 acmaD7^(1P;0!>y!17kze%}X`+GXnrkQwKNz diff --git a/experiments/exp31/figs/scatter_dxdy.pdf b/experiments/exp31/figs/scatter_dxdy.pdf index 02834a457352e9aa5b0d16200e10c334a7400a47..a596aa858d5457d7c1366aaa8f1a5dd5303b5cd3 100644 GIT binary patch delta 25 fcmZoz#L)mmEsR^3SXfz23@uE|wsW#Fedh)MWElqT delta 25 fcmZoz#L)mmEsR^3SXfyN4U7#nvGKj13G8w{Ny&RAvDHe-8+d delta 23 fcmccio8{VXmWC~i>nvFf4NXnVwr{p%RAvDHe?kbD diff --git a/experiments/exp32/figs/hists.pdf b/experiments/exp32/figs/hists.pdf index bfd376cc02741036bfe88974b6ebe83ecfdbe8e1..b7bb7940b266ff45d8cbb61c43b58464bb13f125 100644 GIT binary patch delta 18 ZcmdlKzA1czmAW@7rz4FGo<2!8+o delta 25 hcmZ2Amt)mjj)oS-Ellf}SPczLO-#0LW@7rz4FGo!2z~$n diff --git a/experiments/exp32/figs/traces.pdf b/experiments/exp32/figs/traces.pdf index 57a7cae29b7007c9c72d29175e3780226a416bd5..dd10a31b8cc90e3d0ad4b0392061555adc256a59 100644 GIT binary patch delta 20 ccmezWfcgIe<_*u^vYHqh7#ePV^L8>50C$cGXaE2J delta 20 ccmezWfcgIe<_*u^vKktinwV{V^L8>50C%|xZ~y=R diff --git a/experiments/exp32/get_interim_samples.py b/experiments/exp32/get_interim_samples.py index 86ac204..286de6c 100755 --- a/experiments/exp32/get_interim_samples.py +++ b/experiments/exp32/get_interim_samples.py @@ -10,12 +10,12 @@ from bpd.draw import draw_gaussian from 
bpd.initialization import init_with_truth from bpd.io import save_dataset -from bpd.pipelines.image_samples import ( +from bpd.likelihood import gaussian_image_loglikelihood +from bpd.pipelines import pipeline_interim_samples_one_galaxy +from bpd.prior import interim_gprops_logprior +from bpd.sample import ( get_target_images, get_true_params_from_galaxy_params, - loglikelihood, - logprior, - pipeline_interim_samples_one_galaxy, sample_target_galaxy_params_simple, ) @@ -106,11 +106,11 @@ def main( # setup prior and likelihood _logprior = partial( - logprior, sigma_e=sigma_e_int, free_flux_hlr=True, free_dxdy=True + interim_gprops_logprior, sigma_e=sigma_e_int, free_flux_hlr=True, free_dxdy=True ) _draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size) _loglikelihood = partial( - loglikelihood, + gaussian_image_loglikelihood, draw_fnc=_draw_fnc, background=background, free_flux_hlr=True, diff --git a/experiments/exp32/get_shear.py b/experiments/exp32/get_shear.py index 5fc5362..3221b76 100755 --- a/experiments/exp32/get_shear.py +++ b/experiments/exp32/get_shear.py @@ -2,28 +2,17 @@ """This file creates toy samples of ellipticities and saves them to .hdf5 file.""" from functools import partial -from typing import Callable import jax import jax.numpy as jnp import typer -from jax import Array, jit -from jax._src.prng import PRNGKeyArray +from jax import Array from jax.scipy import stats from bpd import DATA_DIR -from bpd.chains import run_inference_nuts from bpd.io import load_dataset -from bpd.likelihood import shear_loglikelihood, true_ellip_logprior -from bpd.pipelines.image_samples import logprior - - -def logtarget_density( - g: Array, *, data: Array, loglikelihood: Callable, sigma_g: float = 0.01 -): - loglike = loglikelihood(g, post_params=data) - logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum() - return logprior + loglike +from bpd.pipelines import pipeline_shear_inference +from bpd.prior import interim_gprops_logprior, true_ellip_logprior def _logprior( @@ -54,49 +43,11 @@ def _logprior( def _interim_logprior(post_params: dict[str, Array], sigma_e_int: float): # we do not evaluate dxdy as we assume it's the same as the true prior and they cancel - return logprior( + return interim_gprops_logprior( post_params, sigma_e=sigma_e_int, free_flux_hlr=True, free_dxdy=False ) -def pipeline_shear_inference( - rng_key: PRNGKeyArray, - post_params: Array, - init_g: Array, - *, - logprior: Callable, - interim_logprior: Callable, - n_samples: int, - initial_step_size: float, - sigma_g: float = 0.01, - n_warmup_steps: int = 500, - max_num_doublings: int = 2, -): - # NOTE: jit must be applied without `e_post` in partial! 
- _loglikelihood = jit( - partial( - shear_loglikelihood, logprior=logprior, interim_logprior=interim_logprior - ) - ) - _logtarget = partial( - logtarget_density, loglikelihood=_loglikelihood, sigma_g=sigma_g - ) - - _do_inference = partial( - run_inference_nuts, - data=post_params, - logtarget=_logtarget, - n_samples=n_samples, - n_warmup_steps=n_warmup_steps, - max_num_doublings=max_num_doublings, - initial_step_size=initial_step_size, - ) - - g_samples = _do_inference(rng_key, init_g) - - return g_samples - - def main( seed: int, initial_step_size: float = 1e-3, @@ -147,9 +98,9 @@ def main( g_samples = pipeline_shear_inference( rng_key, post_params, + init_g=true_g, logprior=logprior_fnc, interim_logprior=interim_logprior_fnc, - init_g=true_g, n_samples=n_samples, initial_step_size=initial_step_size, ) diff --git a/scripts/benchmarks/benchmark1.py b/scripts/benchmarks/benchmark1.py deleted file mode 100755 index 08c0f4d..0000000 --- a/scripts/benchmarks/benchmark1.py +++ /dev/null @@ -1,316 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we run a variable number of chains on a single galaxy and noise realization (NUTS).""" - -import datetime -import os - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -import time -from functools import partial -from pathlib import Path - -import blackjax -import galsim -import jax -import jax.numpy as jnp -import jax_galsim as xgalsim -import numpy as np -from jax import jit as jjit -from jax import random, vmap -from jax.scipy import stats - -from bpd.chains import inference_loop -from bpd.measure import get_snr -from bpd.noise import add_noise - -print("devices available:", jax.devices()) - -SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains") -LOG_FILE = Path(__file__).parent / "log.txt" - - -# GPU preamble -GPU = jax.devices("gpu")[0] - -jax.config.update("jax_default_device", GPU) - - -PIXEL_SCALE = 0.2 -BACKGROUND = 1e4 -SLEN = 53 -PSF_HLR = 0.7 -GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) - -LOG_FLUX = 4.5 -HLR = 0.9 -G1 = 0.05 -G2 = 0.0 -X = 0.0 -Y = 0.0 - -TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y} - -# make sure relevant things are in GPU -TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU) -BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU) -BOUNDS = { - "f": (-1.0, 9.0), - "hlr": (0.01, 5.0), - "g1": (-0.7, 0.7), - "g2": (-0.7, 0.7), - "x": 1, # sigma (in pixels) - "y": 1, # sigma (in pixels) -} -BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU) - - -# run setup -IS_MATRIX_DIAGONAL = False -N_WARMUPS = 500 -MAX_DOUBLINGS = 5 -N_SAMPLES = 1000 -SEED = 42 -TAG = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - -ALL_N_CHAINS = (1, 5, 10, 25, 50, 100, 150, 200) - - -# sample from ball around some dictionary of true params -def sample_ball(rng_key, center_params: dict): - new = {} - keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} - for p in center_params: - centr = center_params[p] - if p == "f": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25 - ) - elif p == "hlr": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2 - ) - elif p in {"g1", "g2"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025 - ) - elif p in {"x", "y"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5 - ) - return new - - -def 
_draw_gal(): - gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR) - gal = gal.shift(dx=X, dy=Y) - gal = gal.shear(g1=G1, g2=G2) - - psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR) - gal_conv = galsim.Convolve([gal, psf]) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def draw_gal(f, hlr, g1, g2, x, y): - # x, y arguments in pixels - gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr) - gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE) - gal = gal.shear(g1=g1, g2=g2) - - psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR) - gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def _logprob_fn(params, data): - # prior - prior = jnp.array(0.0, device=GPU) - for p in ("f", "hlr", "g1", "g2"): # uniform priors - b1, b2 = BOUNDS_GPU[p] - prior += stats.uniform.logpdf(params[p], b1, b2 - b1) - - for p in ("x", "y"): # normal - sigma = BOUNDS_GPU[p] - prior += stats.norm.logpdf(params[p], sigma) - - # likelihood - model = draw_gal(**params) - likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU)) - likelihood = jnp.sum(likelihood_pp) - - return prior + likelihood - - -def _log_setup(snr: float): - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print( - f"""Running benchmark 1 with configuration as follows. Variable number of chains. - - The sampler used is NUTS with standard warmup. - - TAG: {TAG} - SEED: {SEED} - - Overall sampler configuration (fixed): - max doublings: {MAX_DOUBLINGS} - n_samples: {N_SAMPLES} - n_warmups: {N_WARMUPS} - diagonal matrix: {IS_MATRIX_DIAGONAL} - - galaxy parameters: - LOG_FLUX: {LOG_FLUX} - HLR: {HLR} - G1: {G1} - G2: {G2} - X: {X} - Y: {Y} - - prior bounds: {BOUNDS} - - other parameters: - slen: {SLEN} - psf_hlr: {PSF_HLR} - background: {BACKGROUND} - snr: {snr} - """, - file=f, - ) - - -# vmap only rng_key -def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) - - warmup = blackjax.window_adaptation( - blackjax.nuts, - _logdensity, - progress_bar=False, - is_mass_matrix_diagonal=IS_MATRIX_DIAGONAL, - max_num_doublings=MAX_DOUBLINGS, - initial_step_size=0.1, - target_acceptance_rate=0.90, - ) - return warmup.run( - rng_key, init_position, N_WARMUPS - ) # (init_states, tuned_params), adapt_info - - -def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix): - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.nuts( - _logdensity, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - max_num_doublings=MAX_DOUBLINGS, - ).step - return inference_loop( - rng_key, init_state, kernel=kernel, n_samples=N_SAMPLES - ) # state, info - - -def main(): - snr = get_snr(_draw_gal(), BACKGROUND) - print("galaxy snr:", snr) - - # get data - _data = add_noise(_draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=1)[0] - data_gpu = jax.device_put(_data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - - ball_keys = random.split(ball_key, max(ALL_N_CHAINS)) - warmup_keys = random.split(warmup_key, max(ALL_N_CHAINS)) - sample_keys = random.split(sample_key, max(ALL_N_CHAINS)) - assert warmup_keys.shape == (200,) - - # get initial positions 
for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - assert all_init_positions["f"].shape == (200,) - - # jit and vmap functions to run chains - # same data, multiple chains - _run_warmup = vmap(jjit(do_warmup), in_axes=(0, 0, None)) - _run_inference = vmap(jjit(do_inference), in_axes=(0, 0, None, 0, 0)) - - # results - results = {n: {} for n in ALL_N_CHAINS} - - for ii, n_chains in enumerate(ALL_N_CHAINS): - print(f"n_chains: {n_chains}") - - _keys1 = warmup_keys[:n_chains] - _keys2 = sample_keys[:n_chains] - _init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()} - - if ii == 0: - # compilation times - t1 = time.time() - (_sts, _tp), _ = jax.block_until_ready( - _run_warmup(_keys1, _init_positions, data_gpu) - ) - t2 = time.time() - results[n_chains]["warmup_comp_time"] = t2 - t1 - - t1 = time.time() - _ = jax.block_until_ready( - _run_inference( - _keys2, _sts, data_gpu, _tp["step_size"], _tp["inverse_mass_matrix"] - ) - ) - t2 = time.time() - results[n_chains]["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (init_states, tuned_params), adapt_info = jax.block_until_ready( - _run_warmup(_keys1, _init_positions, data_gpu) - ) - t2 = time.time() - results[n_chains]["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference( - _keys2, - init_states, - data_gpu, - tuned_params["step_size"], - tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results[n_chains]["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results[n_chains]["states"] = states - results[n_chains]["info"] = infos - results[n_chains]["adapt_info"] = adapt_info - results[n_chains]["tuned_params"] = tuned_params - results[n_chains]["data"] = data_gpu - results[n_chains]["init_positions"] = all_init_positions - - filename = f"results_benchmark1_{TAG}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - _log_setup(snr) - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print(f"results were saved to {filepath}", file=f) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark2.py b/scripts/benchmarks/benchmark2.py deleted file mode 100755 index 2b7b761..0000000 --- a/scripts/benchmarks/benchmark2.py +++ /dev/null @@ -1,329 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we run 4 chains each on a variable number of noise realizations (NUTS).""" - -import datetime -import os - -os.environ["CUDA_VISIBLE_DEVICES"] = "1" - - -import time -from functools import partial -from pathlib import Path - -import blackjax -import galsim -import jax -import jax.numpy as jnp -import jax_galsim as xgalsim -import numpy as np -from jax import jit as jjit -from jax import random, vmap -from jax.scipy import stats - -from bpd.chains import inference_loop -from bpd.measure import get_snr -from bpd.noise import add_noise - -print("devices available:", jax.devices()) - -SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains") -LOG_FILE = Path(__file__).parent / "log.txt" - - -# GPU preamble -GPU = jax.devices("gpu")[0] - -jax.config.update("jax_default_device", GPU) - - -PIXEL_SCALE = 0.2 -BACKGROUND = 1e4 -SLEN = 53 -PSF_HLR = 0.7 -GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) - -LOG_FLUX = 4.5 -HLR = 0.9 -G1 = 0.05 -G2 = 0.0 -X = 0.0 -Y = 0.0 - -TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y} - -# make sure relevant things are in 
GPU -TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU) -BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU) -BOUNDS = { - "f": (-1.0, 9.0), - "hlr": (0.01, 5.0), - "g1": (-0.7, 0.7), - "g2": (-0.7, 0.7), - "x": 1, # sigma (in pixels) - "y": 1, # sigma (in pixels) -} -BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU) - - -# run setup -IS_MATRIX_DIAGONAL = False -N_WARMUPS = 500 -MAX_DOUBLINGS = 5 -N_SAMPLES = 1000 -SEED = 42 -TAG = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - -# specific -ALL_N_OBJECTS = (1, 5, 10, 25, 50) # 50 => 200 chains total -N_CHAINS_PER_OBJ = 4 - - -# sample from ball around some dictionary of true params -def sample_ball(rng_key, center_params: dict): - new = {} - keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} - for p in center_params: - centr = center_params[p] - if p == "f": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25 - ) - elif p == "hlr": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2 - ) - elif p in {"g1", "g2"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025 - ) - elif p in {"x", "y"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5 - ) - return new - - -def _draw_gal(): - gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR) - gal = gal.shift(dx=X, dy=Y) - gal = gal.shear(g1=G1, g2=G2) - - psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR) - gal_conv = galsim.Convolve([gal, psf]) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def draw_gal(f, hlr, g1, g2, x, y): - # x, y arguments in pixels - gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr) - gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE) - gal = gal.shear(g1=g1, g2=g2) - - psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR) - gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def _logprob_fn(params, data): - # prior - prior = jnp.array(0.0, device=GPU) - for p in ("f", "hlr", "g1", "g2"): # uniform priors - b1, b2 = BOUNDS_GPU[p] - prior += stats.uniform.logpdf(params[p], b1, b2 - b1) - - for p in ("x", "y"): # normal - sigma = BOUNDS_GPU[p] - prior += stats.norm.logpdf(params[p], sigma) - - # likelihood - model = draw_gal(**params) - likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU)) - likelihood = jnp.sum(likelihood_pp) - - return prior + likelihood - - -def _log_setup(snr: float): - with open(LOG_FILE, "a", encoding='utf-8') as f: - print(file=f) - print( - f"""Running benchmark 2 with configuration as follows. Variable number of chains. - - The sampler used is NUTS with standard warmup. 
- - TAG: {TAG} - SEED: {SEED} - - Overall sampler configuration (fixed): - max doublings: {MAX_DOUBLINGS} - n_samples: {N_SAMPLES} - n_warmups: {N_WARMUPS} - diagonal matrix: {IS_MATRIX_DIAGONAL} - - galaxy parameters: - LOG_FLUX: {LOG_FLUX} - HLR: {HLR} - G1: {G1} - G2: {G2} - X: {X} - Y: {Y} - - prior bounds: {BOUNDS} - - other parameters: - slen: {SLEN} - psf_hlr: {PSF_HLR} - background: {BACKGROUND} - snr: {snr} - """, - file=f, - ) - - -# vmap only rng_key -def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) - - warmup = blackjax.window_adaptation( - blackjax.nuts, - _logdensity, - progress_bar=False, - is_mass_matrix_diagonal=IS_MATRIX_DIAGONAL, - max_num_doublings=MAX_DOUBLINGS, - initial_step_size=0.1, - target_acceptance_rate=0.90, - ) - return warmup.run( - rng_key, init_position, N_WARMUPS - ) # (init_states, tuned_params), adapt_info - - -def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix): - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.nuts( - _logdensity, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - max_num_doublings=MAX_DOUBLINGS, - ).step - return inference_loop( - rng_key, init_state, kernel=kernel, n_samples=N_SAMPLES - ) # state, info - - -def main(): - snr = get_snr(_draw_gal(), BACKGROUND) - print("galaxy snr:", snr) - - max_n_objs = max(ALL_N_OBJECTS) - - # get data - _data = add_noise( - _draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=max_n_objs - ) - data_gpu = jax.device_put(_data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - - ball_keys = random.split(ball_key, max_n_objs * N_CHAINS_PER_OBJ) - warmup_keys = random.split(warmup_key, max_n_objs * N_CHAINS_PER_OBJ) - sample_keys = random.split(sample_key, max_n_objs * N_CHAINS_PER_OBJ) - assert warmup_keys.shape == (200,) - - # get initial positions for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - assert all_init_positions["f"].shape == (200,) - - # jit and vmap functions to run chains - # vmap twice to have multiple chains on same data over multiple data - _run_warmup = vmap(vmap(jjit(do_warmup), in_axes=(0, 0, None)), in_axes=(0, 0, 0)) - _run_inference = vmap( - vmap(jjit(do_inference), in_axes=(0, 0, None, 0, 0)), in_axes=(0, 0, 0, 0, 0) - ) - - # results - results = {n * N_CHAINS_PER_OBJ: {} for n in ALL_N_OBJECTS} - - for ii, n_obj in enumerate(ALL_N_OBJECTS): - n_chains = n_obj * N_CHAINS_PER_OBJ - print(f"n_obj: {n_obj}") - - _keys1 = warmup_keys[:n_chains].reshape(n_obj, N_CHAINS_PER_OBJ) - _keys2 = sample_keys[:n_chains].reshape(n_obj, N_CHAINS_PER_OBJ) - _init_positions = { - p: q[:n_chains].reshape(n_obj, N_CHAINS_PER_OBJ) - for p, q in all_init_positions.items() - } - _data_ii = data_gpu[:n_obj] - - if ii == 0: - # compilation times - t1 = time.time() - (_sts, _tp), _ = jax.block_until_ready( - _run_warmup(_keys1, _init_positions, _data_ii) - ) - t2 = time.time() - results[n_chains]["warmup_comp_time"] = t2 - t1 - - t1 = time.time() - _ = jax.block_until_ready( - _run_inference( - _keys2, _sts, _data_ii, _tp["step_size"], _tp["inverse_mass_matrix"] - ) - ) - t2 = time.time() - results[n_chains]["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (init_states, tuned_params), 
adapt_info = jax.block_until_ready( - _run_warmup(_keys1, _init_positions, _data_ii) - ) - t2 = time.time() - results[n_chains]["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference( - _keys2, - init_states, - _data_ii, - tuned_params["step_size"], - tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results[n_chains]["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results[n_chains]["states"] = states - results[n_chains]["info"] = infos - results[n_chains]["adapt_info"] = adapt_info - results[n_chains]["tuned_params"] = tuned_params - results[n_chains]["data"] = data_gpu - results[n_chains]["init_positions"] = all_init_positions - - filename = f"results_benchmark2_{TAG}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - _log_setup(snr) - with open(LOG_FILE, "a", encoding='utf-8') as f: - print(file=f) - print(f"results were saved to {filepath}", file=f) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark2_7.py b/scripts/benchmarks/benchmark2_7.py deleted file mode 100755 index 893aa68..0000000 --- a/scripts/benchmarks/benchmark2_7.py +++ /dev/null @@ -1,311 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we run multiple chains each on one independent noise realization with NUTS.""" - -import datetime -import os - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -import time -from functools import partial -from pathlib import Path - -import blackjax -import galsim -import jax -import jax.numpy as jnp -import jax_galsim as xgalsim -import numpy as np -from jax import jit as jjit -from jax import random, vmap -from jax.scipy import stats - -from bpd.chains import inference_loop -from bpd.measure import get_snr -from bpd.noise import add_noise - -print("devices available:", jax.devices()) - -SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains") - - -# GPU preamble -GPU = jax.devices("gpu")[0] - -jax.config.update("jax_default_device", GPU) - - -PIXEL_SCALE = 0.2 -BACKGROUND = 1e4 -SLEN = 53 -PSF_HLR = 0.7 -GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) - -LOG_FLUX = 4.5 -HLR = 0.9 -G1 = 0.05 -G2 = 0.0 -X = 0.0 -Y = 0.0 - -TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y} - -# make sure relevant things are in GPU -TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU) -BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU) -BOUNDS = { - "f": (-1.0, 9.0), - "hlr": (0.01, 5.0), - "g1": (-0.7, 0.7), - "g2": (-0.7, 0.7), - "x": 1, # sigma (in pixels) - "y": 1, # sigma (in pixels) -} -BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU) - - -# run setup -N_WARMUPS = 500 -MAX_DOUBLINGS = 5 -N_SAMPLES = 1000 -N_CHAINS = 100 -SEED = 42 -TAG = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - - -# sample from ball around some dictionary of true params -def sample_ball(rng_key, center_params: dict): - new = {} - keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} - for p in center_params: - centr = center_params[p] - if p == "f": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25 - ) - elif p == "hlr": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2 - ) - elif p in {"g1", "g2"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025 - ) - elif p in {"x", "y"}: - new[p] = 
random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5 - ) - return new - - -def _draw_gal(): - gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR) - gal = gal.shift(dx=X, dy=Y) - gal = gal.shear(g1=G1, g2=G2) - - psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR) - gal_conv = galsim.Convolve([gal, psf]) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def draw_gal(f, hlr, g1, g2, x, y): - # x, y arguments in pixels - gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr) - gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE) - gal = gal.shear(g1=g1, g2=g2) - - psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR) - gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def _logprob_fn(params, data): - # prior - prior = jnp.array(0.0, device=GPU) - for p in ("f", "hlr", "g1", "g2"): # uniform priors - b1, b2 = BOUNDS_GPU[p] - prior += stats.uniform.logpdf(params[p], b1, b2 - b1) - - for p in ("x", "y"): # normal - sigma = BOUNDS_GPU[p] - prior += stats.norm.logpdf(params[p], sigma) - - # likelihood - model = draw_gal(**params) - likelihood = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU)) - - return jnp.sum(prior) + jnp.sum(likelihood) - - -LOG_FILE = Path(__file__).parent / "log.txt" - - -def _log_setup(snr: float): - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print( - f"""Running benchmark 2.7 with configuration as follows - Single galaxy with different noise realizations over multiple chains. - The sampler used is NUTS with standard warmup. - - TAG: {TAG} - - Overall configuration: - seed: {SEED} - max doublings: {MAX_DOUBLINGS} - n_samples: {N_SAMPLES} - n_chains: {N_CHAINS} - n_warmups: {N_WARMUPS} - - galaxy parameters: - LOG_FLUX: {LOG_FLUX} - HLR: {HLR} - G1: {G1} - G2: {G2} - X: {X} - Y: {Y} - - prior bounds: {BOUNDS} - - other parameters: - slen: {SLEN} - psf_hlr: {PSF_HLR} - background: {BACKGROUND} - snr: {snr} - """, - file=f, - ) - - -# vmap only rng_key -def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) - - warmup = blackjax.window_adaptation( - blackjax.nuts, - _logdensity, - progress_bar=False, - is_mass_matrix_diagonal=False, - max_num_doublings=MAX_DOUBLINGS, - initial_step_size=0.1, - target_acceptance_rate=0.90, - ) - return warmup.run( - rng_key, init_position, N_WARMUPS - ) # (init_states, tuned_params), adapt_info - - -def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix): - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.nuts( - _logdensity, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - max_num_doublings=MAX_DOUBLINGS, - ).step - return inference_loop( - rng_key, init_state, kernel=kernel, n_samples=N_SAMPLES - ) # state, info - - -def main(): - snr = get_snr(_draw_gal(), BACKGROUND) - print("galaxy snr:", snr) - _log_setup(snr) - - # get data - data = add_noise( - _draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=N_CHAINS - ) - data_gpu = jax.device_put(data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - del rng_key - - ball_keys = random.split(ball_key, N_CHAINS) - warmup_keys = 
random.split(warmup_key, N_CHAINS) - sample_keys = random.split(sample_key, N_CHAINS) - - # get initial positions for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - - # jit and vmap functions to run chains - _run_warmup = vmap(jjit(do_warmup), in_axes=(0, 0, 0)) - _run_inference = vmap(jjit(do_inference), in_axes=(0, 0, 0, 0, 0)) - - # results - results = {} - - # compilation times - t1 = time.time() - (_init_states, _tuned_params), _ = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_comp_time"] = t2 - t1 - - t1 = time.time() - _ = jax.block_until_ready( - _run_inference( - sample_keys, - _init_states, - data_gpu, - _tuned_params["step_size"], - _tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (init_states, tuned_params), adapt_info = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference( - sample_keys, - init_states, - data_gpu, - tuned_params["step_size"], - tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results["states"] = states - results["info"] = infos - results["adapt_info"] = adapt_info - results["tuned_params"] = tuned_params - results["data"] = data - results["init_positions"] = all_init_positions - - filename = f"results_benchmark-v2_7_{TAG}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print(f"results were saved to {filepath}", file=f) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark2_72.py b/scripts/benchmarks/benchmark2_72.py deleted file mode 100755 index 367e8bc..0000000 --- a/scripts/benchmarks/benchmark2_72.py +++ /dev/null @@ -1,266 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we investigate synchronized divergences between chains run in parallel (during warmup). 
-
-We slightly change the galaxy compared to benchmark2_7.
-
-"""
-
-import os
-
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-
-import time
-from functools import partial
-from pathlib import Path
-
-import blackjax
-import galsim
-import jax
-import jax.numpy as jnp
-import jax_galsim as xgalsim
-import numpy as np
-from jax import jit as jjit
-from jax import random, vmap
-from jax.scipy import stats
-
-from bpd.chains import inference_loop
-from bpd.measure import get_snr
-from bpd.noise import add_noise
-
-print("devices available:", jax.devices())
-
-SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains")
-
-
-# GPU preamble
-GPU = jax.devices("gpu")[0]
-
-jax.config.update("jax_default_device", GPU)
-
-
-PIXEL_SCALE = 0.2
-BACKGROUND = 1e4
-SLEN = 53
-PSF_HLR = 0.7
-GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256)
-
-LOG_FLUX = 4.8
-HLR = 1.0
-G1 = -0.05
-G2 = 0.0
-X = 0.0
-Y = 0.0
-
-TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y}
-
-# make sure relevant things are in GPU
-TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU)
-BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU)
-BOUNDS = {
-    "f": (-1.0, 9.0),
-    "hlr": (0.01, 5.0),
-    "g1": (-0.7, 0.7),
-    "g2": (-0.7, 0.7),
-    "x": 1,  # sigma (in pixels)
-    "y": 1,  # sigma (in pixels)
-}
-BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU)
-
-
-# sample from ball around some dictionary of true params
-def sample_ball(rng_key, center_params: dict):
-    new = {}
-    keys = random.split(rng_key, len(center_params.keys()))
-    rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)}
-    for p in center_params:
-        centr = center_params[p]
-        if p == "f":
-            new[p] = random.uniform(
-                rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25
-            )
-        elif p == "hlr":
-            new[p] = random.uniform(
-                rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2
-            )
-        elif p in {"g1", "g2"}:
-            new[p] = random.uniform(
-                rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025
-            )
-        elif p in {"x", "y"}:
-            new[p] = random.uniform(
-                rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5
-            )
-    return new
-
-
-def _draw_gal():
-    gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR)
-    gal = gal.shift(dx=X, dy=Y)
-    gal = gal.shear(g1=G1, g2=G2)
-
-    psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR)
-    gal_conv = galsim.Convolve([gal, psf])
-    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
-    return image.array
-
-
-def draw_gal(f, hlr, g1, g2, x, y):
-    # x, y arguments in pixels
-    gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr)
-    gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE)
-    gal = gal.shear(g1=g1, g2=g2)
-
-    psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR)
-    gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS)
-    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
-    return image.array
-
-
-def _logprob_fn(params, data):
-    # prior
-    prior = jnp.array(0.0, device=GPU)
-    for p in ("f", "hlr", "g1", "g2"):  # uniform priors
-        b1, b2 = BOUNDS_GPU[p]
-        prior += stats.uniform.logpdf(params[p], b1, b2 - b1)
-
-    for p in ("x", "y"):  # normal
-        sigma = BOUNDS_GPU[p]
-        prior += stats.norm.logpdf(params[p], sigma)
-
-    # likelihood
-    model = draw_gal(**params)
-    likelihood = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU))
-
-    return jnp.sum(prior) + jnp.sum(likelihood)
-
-
-# run setup
-N_WARMUPS = 500
-MAX_DOUBLINGS = 5
-N_SAMPLES = 100
-N_CHAINS = 10
-SEED = 43
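-
-# A minimal sketch of the compile-vs-run timing pattern used in main() below
-# (names `fn` and `args` are placeholders, not part of this script): the first
-# blocked call pays for tracing + XLA compilation, so it is timed separately;
-# the second call measures only the cached executable.
-#
-#   t0 = time.time()
-#   jax.block_until_ready(fn(*args))  # triggers compilation, then runs
-#   comp_time = time.time() - t0
-#
-#   t0 = time.time()
-#   jax.block_until_ready(fn(*args))  # runs the cached executable only
-#   run_time = time.time() - t0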
- - -# vmap only rng_key -def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) - - warmup = blackjax.window_adaptation( - blackjax.nuts, - _logdensity, - progress_bar=False, - is_mass_matrix_diagonal=False, - max_num_doublings=MAX_DOUBLINGS, - initial_step_size=0.1, - target_acceptance_rate=0.90, - ) - return warmup.run( - rng_key, init_position, N_WARMUPS - ) # (init_states, tuned_params), adapt_info - - -def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix): - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.nuts( - _logdensity, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - max_num_doublings=MAX_DOUBLINGS, - ).step - return inference_loop( - rng_key, init_state, kernel=kernel, n_samples=N_SAMPLES - ) # state, info - - -def main(): - print("galaxy snr:", get_snr(_draw_gal(), BACKGROUND)) - - # get data - data = add_noise( - _draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=N_CHAINS - ) - data_gpu = jax.device_put(data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - - ball_keys = random.split(ball_key, N_CHAINS) - warmup_keys = random.split(warmup_key, N_CHAINS) - sample_keys = random.split(sample_key, N_CHAINS) - - # get initial positions for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - - # jit and vmap functions to run chains - _run_warmup = vmap(jjit(do_warmup), in_axes=(0, 0, 0)) - _run_inference = vmap(jjit(do_inference), in_axes=(0, 0, 0, 0, 0)) - - # results - results = {} - - # compilation times - t1 = time.time() - (_init_states, _tuned_params), _ = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_comp_time"] = t2 - t1 - - t1 = time.time() - _ = jax.block_until_ready( - _run_inference( - sample_keys, - _init_states, - data_gpu, - _tuned_params["step_size"], - _tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (init_states, tuned_params), adapt_info = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference( - sample_keys, - init_states, - data_gpu, - tuned_params["step_size"], - tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results["states"] = states - results["info"] = infos - results["adapt_info"] = adapt_info - results["tuned_params"] = tuned_params - results["data"] = data - results["init_positions"] = all_init_positions - - filename = f"results_benchmark-v2_72_{N_CHAINS}_{SEED}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark2_8.py b/scripts/benchmarks/benchmark2_8.py deleted file mode 100755 index 2a1c5eb..0000000 --- a/scripts/benchmarks/benchmark2_8.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we run multiple chains each on one galaxy, same noise realization. - -To get a sense of efficiency. 
- -""" - -import os - -os.environ["CUDA_VISIBLE_DEVICES"] = "1" - - -import time -from functools import partial -from pathlib import Path - -import blackjax -import galsim -import jax -import jax.numpy as jnp -import jax_galsim as xgalsim -import numpy as np -from jax import jit as jjit -from jax import random, vmap -from jax.scipy import stats - -from bpd.chains import inference_loop -from bpd.measure import get_snr -from bpd.noise import add_noise - -print("devices available:", jax.devices()) - -SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains") - - -# GPU preamble -GPU = jax.devices("gpu")[0] - -jax.config.update("jax_default_device", GPU) - - -PIXEL_SCALE = 0.2 -BACKGROUND = 1e4 -SLEN = 53 -PSF_HLR = 0.7 -GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) - -LOG_FLUX = 4.5 -HLR = 0.9 -G1 = 0.05 -G2 = 0.0 -X = 0.0 -Y = 0.0 - -TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y} - -# make sure relevant things are in GPU -TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU) -BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU) -BOUNDS = { - "f": (-1.0, 9.0), - "hlr": (0.01, 5.0), - "g1": (-0.7, 0.7), - "g2": (-0.7, 0.7), - "x": 1, # sigma (in pixels) - "y": 1, # sigma (in pixels) -} -BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU) - - -# sample from ball around some dictionary of true params -def sample_ball(rng_key, center_params: dict): - new = {} - keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} - for p in center_params: - centr = center_params[p] - if p == "f": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25 - ) - elif p == "hlr": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2 - ) - elif p in {"g1", "g2"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025 - ) - elif p in {"x", "y"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5 - ) - return new - - -def _draw_gal(): - gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR) - gal = gal.shift(dx=X, dy=Y) - gal = gal.shear(g1=G1, g2=G2) - - psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR) - gal_conv = galsim.Convolve([gal, psf]) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def draw_gal(f, hlr, g1, g2, x, y): - # x, y arguments in pixels - gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr) - gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE) - gal = gal.shear(g1=g1, g2=g2) - - psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR) - gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def _logprob_fn(params, data): - # prior - prior = jnp.array(0.0, device=GPU) - for p in ("f", "hlr", "g1", "g2"): # uniform priors - b1, b2 = BOUNDS_GPU[p] - prior += stats.uniform.logpdf(params[p], b1, b2 - b1) - - for p in ("x", "y"): # normal - sigma = BOUNDS_GPU[p] - prior += stats.norm.logpdf(params[p], sigma) - - # likelihood - model = draw_gal(**params) - likelihood = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU)) - - return jnp.sum(prior) + jnp.sum(likelihood) - - -# run setup -N_WARMUPS = 500 -MAX_DOUBLINGS = 1 -N_SAMPLES = 1000 -N_CHAINS = 100 -SEED = 42 -TAG = "md=1" - - -LOG_FILE = Path(__file__).parent / 
"log.txt" - - -def _log_setup(snr: float): - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print( - f"""Running benchmark 2.8 with configuration as follows - Single galaxy with one noise realizations over multiple chains. - - Overall configuration: - seed: {SEED} - max doublings: {MAX_DOUBLINGS} - n_samples: 100 - n_chains: 10 - - galaxy parameters: - LOG_FLUX: {LOG_FLUX} - HLR: {HLR} - G1: {G1} - G2: {G2} - X: {X} - Y: {Y} - - prior bounds: {BOUNDS} - - other parameters: - slen: {SLEN} - psf_hlr: {PSF_HLR} - background: {BACKGROUND} - snr: {snr} - """, - file=f, - ) - - -# vmap only rng_key -def do_warmup(rng_key, init_position: dict, data): - _logdensity = partial(_logprob_fn, data=data) - - warmup = blackjax.window_adaptation( - blackjax.nuts, - _logdensity, - progress_bar=False, - is_mass_matrix_diagonal=False, - max_num_doublings=MAX_DOUBLINGS, - initial_step_size=0.1, - target_acceptance_rate=0.90, - ) - return warmup.run( - rng_key, init_position, N_WARMUPS - ) # (init_states, tuned_params), adapt_info - - -def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix): - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.nuts( - _logdensity, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - max_num_doublings=MAX_DOUBLINGS, - ).step - return inference_loop( - rng_key, init_state, kernel=kernel, n_samples=N_SAMPLES - ) # state, info - - -def main(): - snr = get_snr(_draw_gal(), BACKGROUND) - print("galaxy snr:", snr) - _log_setup(snr) - - # get data - data = add_noise(_draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=1)[0] - data_gpu = jax.device_put(data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - del rng_key - - ball_keys = random.split(ball_key, N_CHAINS) - warmup_keys = random.split(warmup_key, N_CHAINS) - sample_keys = random.split(sample_key, N_CHAINS) - - # get initial positions for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - - # jit and vmap functions to run chains - _run_warmup = vmap(jjit(do_warmup), in_axes=(0, 0, None)) - _run_inference = vmap(jjit(do_inference), in_axes=(0, 0, None, 0, 0)) - - # results - results = {} - - # compilation times - t1 = time.time() - (_init_states, _tuned_params), _ = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_comp_time"] = t2 - t1 - - t1 = time.time() - _ = jax.block_until_ready( - _run_inference( - sample_keys, - _init_states, - data_gpu, - _tuned_params["step_size"], - _tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (init_states, tuned_params), adapt_info = jax.block_until_ready( - _run_warmup(warmup_keys, all_init_positions, data_gpu) - ) - t2 = time.time() - results["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference( - sample_keys, - init_states, - data_gpu, - tuned_params["step_size"], - tuned_params["inverse_mass_matrix"], - ) - ) - t2 = time.time() - results["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results["states"] = states - results["info"] = infos - results["adapt_info"] = adapt_info - 
results["tuned_params"] = tuned_params - results["data"] = data - results["init_positions"] = all_init_positions - - filename = f"results_benchmark-v2_8_{N_CHAINS}_{SEED}_{TAG}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print(f"results were saved to {filepath}", file=f) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark_chees1.py b/scripts/benchmarks/benchmark_chees1.py deleted file mode 100755 index 601fddd..0000000 --- a/scripts/benchmarks/benchmark_chees1.py +++ /dev/null @@ -1,299 +0,0 @@ -#!/usr/bin/env python3 - -"""Here we run a variable number of chains on a single galaxy and noise realization (NUTS).""" - -import datetime -import os - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -import time -from functools import partial -from pathlib import Path - -import blackjax -import galsim -import jax -import jax.numpy as jnp -import jax_galsim as xgalsim -import numpy as np -import optax -from jax import random, vmap -from jax.scipy import stats - -from bpd.chains import inference_loop -from bpd.measure import get_snr -from bpd.noise import add_noise - -print("devices available:", jax.devices()) - -SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains") - - -# GPU preamble -GPU = jax.devices("gpu")[0] - -jax.config.update("jax_default_device", GPU) - -LOG_FILE = Path(__file__).parent / "log.txt" - - -PIXEL_SCALE = 0.2 -BACKGROUND = 1e4 -SLEN = 53 -PSF_HLR = 0.7 -GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) - -LOG_FLUX = 4.5 -HLR = 0.9 -G1 = 0.05 -G2 = 0.0 -X = 0.0 -Y = 0.0 - -TRUE_PARAMS = {"f": LOG_FLUX, "hlr": HLR, "g1": G1, "g2": G2, "x": X, "y": Y} - -# make sure relevant things are in GPU -TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU) -BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU) -BOUNDS = { - "f": (-1.0, 9.0), - "hlr": (0.01, 5.0), - "g1": (-0.7, 0.7), - "g2": (-0.7, 0.7), - "x": 1, # sigma (in pixels) - "y": 1, # sigma (in pixels) -} -BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU) - - -# run setup -IS_MATRIX_DIAGONAL = True -N_WARMUPS = 500 -N_SAMPLES = 1000 -SEED = 42 -TAG = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - -# chees setup -LR = 1e-3 -INIT_STEP_SIZE = 0.1 - -ALL_N_CHAINS = (1, 5, 10, 25, 50, 100, 150, 200, 300, 500) - - -# sample from ball around some dictionary of true params -def sample_ball(rng_key, center_params: dict): - new = {} - keys = random.split(rng_key, len(center_params.keys())) - rng_key_dict = {p: k for p, k in zip(center_params, keys, strict=False)} - for p in center_params: - centr = center_params[p] - if p == "f": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.25, maxval=centr + 0.25 - ) - elif p == "hlr": - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.2, maxval=centr + 0.2 - ) - elif p in {"g1", "g2"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.025, maxval=centr + 0.025 - ) - elif p in {"x", "y"}: - new[p] = random.uniform( - rng_key_dict[p], shape=(), minval=centr - 0.5, maxval=centr + 0.5 - ) - return new - - -def _draw_gal(): - gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR) - gal = gal.shift(dx=X, dy=Y) - gal = gal.shear(g1=G1, g2=G2) - - psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR) - gal_conv = galsim.Convolve([gal, psf]) - image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE) - return image.array - - -def draw_gal(f, hlr, 
g1, g2, x, y):
-    # x, y arguments in pixels
-    gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr)
-    gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE)
-    gal = gal.shear(g1=g1, g2=g2)
-
-    psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR)
-    gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS)
-    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
-    return image.array
-
-
-def _logprob_fn(params, data):
-    # prior
-    prior = jnp.array(0.0, device=GPU)
-    for p in ("f", "hlr", "g1", "g2"):  # uniform priors
-        b1, b2 = BOUNDS_GPU[p]
-        prior += stats.uniform.logpdf(params[p], b1, b2 - b1)
-
-    for p in ("x", "y"):  # normal
-        sigma = BOUNDS_GPU[p]
-        prior += stats.norm.logpdf(params[p], sigma)
-
-    # likelihood
-    model = draw_gal(**params)
-    likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU))
-    likelihood = jnp.sum(likelihood_pp)
-
-    return prior + likelihood
-
-
-def _log_setup(snr: float):
-    with open(LOG_FILE, "a", encoding="utf-8") as f:
-        print(file=f)
-        print(
-            f"""Running benchmark chees 1 with configuration as follows. Variable number of chains.
-
-    The sampler used is dynamic HMC with ChEES adaptation.
-
-    TAG: {TAG}
-    SEED: {SEED}
-
-    Overall sampler configuration (fixed):
-        n_samples: {N_SAMPLES}
-        n_warmups: {N_WARMUPS}
-        diagonal matrix: {IS_MATRIX_DIAGONAL}
-        learning_rate: {LR}
-        init_step_size: {INIT_STEP_SIZE}
-
-    galaxy parameters:
-        LOG_FLUX: {LOG_FLUX}
-        HLR: {HLR}
-        G1: {G1}
-        G2: {G2}
-        X: {X}
-        Y: {Y}
-
-    prior bounds: {BOUNDS}
-
-    other parameters:
-        slen: {SLEN}
-        psf_hlr: {PSF_HLR}
-        background: {BACKGROUND}
-        snr: {snr}
-    """,
-            file=f,
-        )
-
-
-def do_warmup(rng_key, positions, data, n_chains: int = None):
-    """Cannot be jitted, but it seems to compile automatically after running once."""
-    logdensity = partial(_logprob_fn, data=data)
-    warmup = blackjax.chees_adaptation(logdensity, n_chains)
-    optim = optax.adam(LR)
-    # `positions` = PyTree where each leaf has shape (num_chains, ...)
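-    # e.g., a hypothetical input for this model:
-    #   positions = {"f": jnp.zeros(n_chains), "hlr": jnp.ones(n_chains), ...}
-    # ChEES adaptation uses the whole ensemble of chains to tune a shared
-    # step size and trajectory length, hence the extra `n_chains` argument.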
- return warmup.run(rng_key, positions, INIT_STEP_SIZE, optim, N_WARMUPS) - - -def do_inference(rng_key, init_states, data, tuned_params: dict): - """Also won't jit for unknown reasons""" - _logdensity = partial(_logprob_fn, data=data) - kernel = blackjax.dynamic_hmc(_logdensity, **tuned_params).step - return inference_loop(rng_key, init_states, kernel=kernel, n_samples=N_SAMPLES) - - -def main(): - print("TAG:", TAG) - snr = get_snr(_draw_gal(), BACKGROUND) - print("galaxy snr:", snr) - - # get data - _data = add_noise(_draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=1)[0] - data_gpu = jax.device_put(_data, device=GPU) - print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape) - - # collect random keys we need - rng_key = random.key(SEED) - rng_key = jax.device_put(rng_key, device=GPU) - - ball_key, warmup_key, sample_key = random.split(rng_key, 3) - - warmup_keys = random.split(warmup_key, len(ALL_N_CHAINS)) - ball_keys = random.split(ball_key, max(ALL_N_CHAINS)) - sample_keys = random.split(sample_key, max(ALL_N_CHAINS)) - assert sample_keys.shape == (max(ALL_N_CHAINS),) - - # get initial positions for all chains - all_init_positions = vmap(sample_ball, in_axes=(0, None))( - ball_keys, TRUE_PARAMS_GPU - ) - assert all_init_positions["f"].shape == (max(ALL_N_CHAINS),) - - # jit and vmap functions to run chains - _run_inference = vmap(do_inference, in_axes=(0, 0, None, None)) - - # results - results = {n: {} for n in ALL_N_CHAINS} - - for ii, n_chains in enumerate(ALL_N_CHAINS): - print(f"n_chains: {n_chains}") - - _key1 = warmup_keys[ii] - _keys2 = sample_keys[:n_chains] - _init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()} - - _run_warmup = partial(do_warmup, n_chains=n_chains) - - # compilation times for warmup - t1 = time.time() - (_sts, _tp), _ = jax.block_until_ready( - _run_warmup(_key1, _init_positions, data_gpu) - ) - t2 = time.time() - results[n_chains]["warmup_comp_time"] = t2 - t1 - - # inference compilation time - if ii == 0: - t1 = time.time() - _ = jax.block_until_ready(_run_inference(_keys2, _sts, data_gpu, _tp)) - t2 = time.time() - results["inference_comp_time"] = t2 - t1 - - # run times - t1 = time.time() - (last_states, tuned_params), adapt_info = jax.block_until_ready( - _run_warmup(_key1, _init_positions, data_gpu) - ) - t2 = time.time() - results[n_chains]["warmup_run_time"] = t2 - t1 - - t1 = time.time() - states, infos = jax.block_until_ready( - _run_inference(_keys2, last_states, data_gpu, tuned_params) - ) - t2 = time.time() - results[n_chains]["inference_run_time"] = t2 - t1 - - # save states and info for future reference - results[n_chains]["states"] = states - results[n_chains]["info"] = infos - results[n_chains]["adapt_info"] = adapt_info - results[n_chains]["step_size"] = tuned_params["step_size"] - - results["data"] = data_gpu - results["init_positions"] = all_init_positions - - filename = f"results_chees_benchmark1_{TAG}.npy" - filepath = SCRATCH_DIR.joinpath(filename) - jnp.save(filepath, results) - - _log_setup(snr) - with open(LOG_FILE, "a", encoding="utf-8") as f: - print(file=f) - print(f"results were saved to {filepath}", file=f) - - -if __name__ == "__main__": - main() diff --git a/scripts/get_shear_from_interim_samples.py b/scripts/get_shear_from_shapes.py similarity index 92% rename from scripts/get_shear_from_interim_samples.py rename to scripts/get_shear_from_shapes.py index a88264b..a1ddd33 100755 --- a/scripts/get_shear_from_interim_samples.py +++ b/scripts/get_shear_from_shapes.py @@ -9,7 +9,7 
@@ from bpd import DATA_DIR from bpd.io import load_dataset -from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities +from bpd.pipelines import pipeline_shear_inference_simple def _extract_seed(fpath: str) -> int: @@ -47,7 +47,7 @@ def main( sigma_e_int = samples_dataset["sigma_e_int"] rng_key = jax.random.key(seed) - g_samples = pipeline_shear_inference_ellipticities( + g_samples = pipeline_shear_inference_simple( rng_key, e_post, init_g=true_g, diff --git a/scripts/slurm_job.py b/scripts/slurm/slurm_job.py similarity index 100% rename from scripts/slurm_job.py rename to scripts/slurm/slurm_job.py diff --git a/scripts/slurm_toy_shear_vectorized.py b/scripts/slurm/slurm_toy_shear_vectorized.py similarity index 97% rename from scripts/slurm_toy_shear_vectorized.py rename to scripts/slurm/slurm_toy_shear_vectorized.py index 66c5d0d..2407ea2 100755 --- a/scripts/slurm_toy_shear_vectorized.py +++ b/scripts/slurm/slurm_toy_shear_vectorized.py @@ -2,9 +2,9 @@ import subprocess import typer +from slurm_job import setup_sbatch_job_gpu from bpd import DATA_DIR -from scripts.slurm_job import setup_sbatch_job_gpu def main( diff --git a/scripts/toy_shear_vectorized.py b/scripts/toy_shear_vectorized.py index 37e4849..e8f3af3 100755 --- a/scripts/toy_shear_vectorized.py +++ b/scripts/toy_shear_vectorized.py @@ -7,8 +7,7 @@ from jax import random, vmap from bpd import DATA_DIR -from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities -from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples +from bpd.pipelines import pipeline_shear_inference_simple, pipeline_toy_ellips def main( @@ -40,7 +39,7 @@ def main( n_batch = ceil(len(keys) / n_vec) pipe1 = partial( - pipeline_toy_ellips_samples, + pipeline_toy_ellips, g1=g1, g2=g2, sigma_e=shape_noise, @@ -50,7 +49,7 @@ def main( n_samples_per_gal=n_samples_per_gal, ) pipe2 = partial( - pipeline_shear_inference_ellipticities, + pipeline_shear_inference_simple, init_g=jnp.array([g1, g2]), sigma_e=shape_noise, sigma_e_int=sigma_e_int, diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 12d0191..e9747d4 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -8,10 +8,12 @@ from jax import jit, random, vmap from bpd.chains import run_inference_nuts -from bpd.pipelines.shear_inference import pipeline_shear_inference_ellipticities -from bpd.pipelines.toy_ellips import logtarget as logtarget_toy_ellips -from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples -from bpd.prior import sample_noisy_ellipticities_unclipped +from bpd.pipelines import ( + logtarget_toy_ellips, + pipeline_shear_inference_simple, + pipeline_toy_ellips, +) +from bpd.sample import sample_noisy_ellipticities_unclipped @pytest.mark.parametrize("seed", [1234, 4567]) @@ -84,7 +86,7 @@ def test_toy_shear_convergence(seed): key = random.key(seed) k1, k2 = random.split(key) - e_post, _, _ = pipeline_toy_ellips_samples( + e_post, _, _ = pipeline_toy_ellips( k1, g1=g1, g2=g2, @@ -97,7 +99,7 @@ def test_toy_shear_convergence(seed): # run 4 shear chains over the given e_post _pipeline_shear1 = partial( - pipeline_shear_inference_ellipticities, + pipeline_shear_inference_simple, init_g=true_g, sigma_e=sigma_e, sigma_e_int=sigma_e_int, diff --git a/tests/test_shear_inference.py b/tests/test_shear_inference.py index eef17d2..c1afa3a 100644 --- a/tests/test_shear_inference.py +++ b/tests/test_shear_inference.py @@ -5,8 +5,7 @@ import pytest from jax import random -from bpd.pipelines.shear_inference import 
pipeline_shear_inference_ellipticities -from bpd.pipelines.toy_ellips import pipeline_toy_ellips_samples +from bpd.pipelines import pipeline_shear_inference_simple, pipeline_toy_ellips @pytest.mark.parametrize("seed", [1234, 4567]) @@ -22,7 +21,7 @@ def test_shear_inference_toy_ellipticities(seed): true_g = jnp.array([g1, g2]) n_gals = 1000 - e_post, _, e_sheared = pipeline_toy_ellips_samples( + e_post, _, e_sheared = pipeline_toy_ellips( k1, g1=g1, g2=g2, @@ -35,7 +34,7 @@ def test_shear_inference_toy_ellipticities(seed): assert e_post.shape == (n_gals, 100, 2) e_post_trimmed = e_post[:, ::10, :] - shear_samples = pipeline_shear_inference_ellipticities( + shear_samples = pipeline_shear_inference_simple( k2, e_post_trimmed, init_g=true_g, diff --git a/tests/test_shear_trans.py b/tests/test_shear_trans.py index 99b3ee0..70cb9b5 100644 --- a/tests/test_shear_trans.py +++ b/tests/test_shear_trans.py @@ -7,9 +7,9 @@ from jax import jit, random from jax_galsim import GSParams -from bpd.prior import ( +from bpd.sample import sample_ellip_prior +from bpd.shear import ( inv_shear_transformation, - sample_ellip_prior, scalar_inv_shear_transformation, scalar_shear_transformation, shear_transformation,