Refactoring (Jan 2025 edition) (#80)
* benchmarks have been superseded, might revive later

* might put more random stuff here

* all reusable priors and likelihoods go here!

* noise is part of drawing?

* shear transformations now go here

* reusable and most generalized pipelines go here

* generating samples of parameters or images

* pipelines got consolidated

* fix one test

* it's useful to keep the simple pipeline that only does ellipticities

* fix another test

* much less needs to be added

* fix last test

* organize scripts

* fix experiments

* revert changes to rerun
ismael-mendoza authored Jan 14, 2025
1 parent 861b34f commit 3d48d44
Showing 48 changed files with 527 additions and 2,392 deletions.
21 changes: 21 additions & 0 deletions bpd/draw.py
@@ -1,5 +1,9 @@
import galsim
import jax.numpy as jnp
import jax_galsim as xgalsim
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.typing import ArrayLike
from jax_galsim import GSParams


@@ -50,3 +54,20 @@ def draw_gaussian_galsim(
gal_conv = galsim.Convolve([gal, psf])
image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))
return image.array


def add_noise(
rng_key: PRNGKeyArray,
x: ArrayLike,
bg: float,
n: int = 1,
):
"""Produce `n` independent Gaussian noise realizations of a given image `x`.
    NOTE: This function assumes the image is background-subtracted and background-dominated.
"""
assert isinstance(bg, float) or bg.shape == ()
x = x.reshape(1, *x.shape)
x = x.repeat(n, axis=0)
noise = random.normal(rng_key, shape=x.shape) * jnp.sqrt(bg)
return x + noise
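
A minimal usage sketch for `add_noise` (the image shape and background value below are illustrative, not from this commit):

import jax.numpy as jnp
from jax import random

from bpd.draw import add_noise

key = random.PRNGKey(42)
image = jnp.zeros((53, 53))  # stand-in for a drawn, background-subtracted galaxy image
noisy = add_noise(key, image, bg=100.0, n=10)  # 10 independent noise realizations
assert noisy.shape == (10, 53, 53)
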
67 changes: 39 additions & 28 deletions bpd/likelihood.py
@@ -2,39 +2,14 @@

import jax.numpy as jnp
import jax.scipy as jsp
from jax import Array, grad, vmap
from jax import Array
from jax.scipy import stats
from jax.typing import ArrayLike

from bpd.prior import (
ellip_prior_e1e2,
inv_shear_func1,
inv_shear_func2,
inv_shear_transformation,
)

_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))
_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None))


def true_ellip_logprior(e_post: Array, g: Array, *, sigma_e: float):
"""Implementation of GB's true prior on interim posterior samples of ellipticities."""

# jacobian of inverse shear transformation
grad1 = _grad_fnc1(e_post, g)
grad2 = _grad_fnc2(e_post, g)
absjacdet = jnp.abs(grad1[..., 0] * grad2[..., 1] - grad1[..., 1] * grad2[..., 0])

# true prior on unsheared ellipticity
e_post_unsheared = _inv_shear_trans(e_post, g)
prior_val = ellip_prior_e1e2(e_post_unsheared, sigma=sigma_e)

return jnp.log(prior_val) + jnp.log(absjacdet)


def shear_loglikelihood(
g: Array,
post_params: dict[str, Array],
post_params: dict[str, Array] | Array,
*,
logprior: Callable,
interim_logprior: Callable, # fixed
@@ -44,3 +19,39 @@ def shear_loglikelihood(
num = logprior(post_params, g)
ratio = jsp.special.logsumexp(num - denom, axis=-1)
return ratio.sum()
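
# A reading of the estimator above (a sketch, up to additive constants): given
# interim posterior samples e_{is} for each galaxy i, it computes
#     log L(g) ≈ sum_i log sum_s [ p(e_{is} | g) / p_int(e_{is}) ],
# where `logprior` evaluates log p(e | g), `interim_logprior` evaluates
# log p_int(e), and the inner sum over samples is done stably via logsumexp.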


def gaussian_image_loglikelihood(
params: dict[str, Array],
data: Array,
fixed_params: dict[str, Array],
*,
draw_fnc: Callable,
background: float,
free_flux_hlr: bool = True,
free_dxdy: bool = True,
):
_draw_params = {}

if free_dxdy:
_draw_params["x"] = params["dx"] + fixed_params["x"]
_draw_params["y"] = params["dy"] + fixed_params["y"]

else:
_draw_params["x"] = fixed_params["x"]
_draw_params["y"] = fixed_params["y"]

if free_flux_hlr:
_draw_params["f"] = 10 ** params["lf"]
_draw_params["hlr"] = 10 ** params["lhlr"]

else:
_draw_params["f"] = fixed_params["f"]
_draw_params["hlr"] = fixed_params["hlr"]

_draw_params["e1"] = params["e1"]
_draw_params["e2"] = params["e2"]

model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
return jnp.sum(likelihood_pp)
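
A hedged usage sketch for `gaussian_image_loglikelihood`; `my_draw` is a hypothetical stand-in (any JAX-differentiable drawing function accepting x, y, f, hlr, e1, e2 should work):

from functools import partial

import jax.numpy as jnp

from bpd.likelihood import gaussian_image_loglikelihood

def my_draw(x, y, f, hlr, e1, e2):
    # hypothetical stand-in; a real draw_fnc renders a model galaxy image
    return jnp.zeros((53, 53))

loglike = partial(
    gaussian_image_loglikelihood,
    draw_fnc=my_draw,
    background=100.0,
    free_flux_hlr=False,  # take flux and HLR from `fixed_params`
    free_dxdy=False,      # take the centroid from `fixed_params`
)
params = {"e1": 0.1, "e2": -0.05}
fixed = {"x": 0.0, "y": 0.0, "f": 1e4, "hlr": 0.9}
val = loglike(params, jnp.zeros((53, 53)), fixed)  # second argument: the observed image
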
21 changes: 0 additions & 21 deletions bpd/noise.py

This file was deleted.

207 changes: 207 additions & 0 deletions bpd/pipelines.py
@@ -0,0 +1,207 @@
from functools import partial
from typing import Callable

import jax.numpy as jnp
import jax.scipy as jsp
from jax import Array, jit, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import run_inference_nuts
from bpd.likelihood import shear_loglikelihood
from bpd.prior import ellip_prior_e1e2, true_ellip_logprior
from bpd.sample import sample_noisy_ellipticities_unclipped


def logtarget_shear(
g: Array, *, data: Array | dict[str, Array], loglikelihood: Callable, sigma_g: float
):
loglike = loglikelihood(g, post_params=data)
logprior = stats.norm.logpdf(g, loc=0.0, scale=sigma_g).sum()
return logprior + loglike


def pipeline_shear_inference(
rng_key: PRNGKeyArray,
post_params: Array,
init_g: Array,
*,
logprior: Callable,
interim_logprior: Callable,
n_samples: int,
initial_step_size: float,
sigma_g: float = 0.01,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
_loglikelihood = partial(
shear_loglikelihood, logprior=logprior, interim_logprior=interim_logprior
)
_loglikelihood_jitted = jit(_loglikelihood)

_logtarget = partial(
logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g
)

_do_inference = partial(
run_inference_nuts,
data=post_params,
logtarget=_logtarget,
n_samples=n_samples,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
)

g_samples = _do_inference(rng_key, init_g)
return g_samples


def pipeline_shear_inference_simple(
rng_key: PRNGKeyArray,
e_post: Array,
init_g: Array,
*,
sigma_e: float,
sigma_e_int: float,
n_samples: int,
initial_step_size: float,
sigma_g: float = 0.01,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
_logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e)
_interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int))

_loglikelihood = partial(
shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior
)
_loglikelihood_jitted = jit(_loglikelihood)

_logtarget = partial(
logtarget_shear, loglikelihood=_loglikelihood_jitted, sigma_g=sigma_g
)

_do_inference = partial(
run_inference_nuts,
data=e_post,
logtarget=_logtarget,
n_samples=n_samples,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
)

return _do_inference(rng_key, init_g)
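
# Usage sketch (shapes are an assumption, mirroring `pipeline_toy_ellips`
# below): `e_post` holds interim posterior ellipticity samples, e.g. of shape
# (n_gals, n_samples_per_gal, 2), `init_g = jnp.array([0.0, 0.0])` is a
# natural starting point, and the result is an array of posterior shear draws.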


def logtarget_images(
params: dict[str, Array],
data: Array,
*,
fixed_params: dict[str, Array],
logprior_fnc: Callable,
loglikelihood_fnc: Callable,
):
return logprior_fnc(params) + loglikelihood_fnc(params, data, fixed_params)


def pipeline_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
fixed_params: dict[str, float],
*,
initialization_fnc: Callable,
logprior: Callable,
loglikelihood: Callable,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
):
# Flux and HLR are fixed to truth and not inferred in this function.
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)

_logtarget = partial(
logtarget_images,
logprior_fnc=logprior,
loglikelihood_fnc=loglikelihood,
fixed_params=fixed_params,
)

_inference_fnc = partial(
run_inference_nuts,
logtarget=_logtarget,
is_mass_matrix_diagonal=is_mass_matrix_diagonal,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
n_samples=n_samples,
)
_run_inference = jit(_inference_fnc)

interim_samples = _run_inference(k2, init_position, target_image)
return interim_samples
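
# NOTE: as the name says, this handles a single galaxy; a plausible pattern
# (an assumption, mirroring `pipeline_toy_ellips` below) is to split the key
# with `random.split` and vmap or loop this over galaxies and target images.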


def logtarget_toy_ellips(
e_sheared: Array,
*,
    data: Array,  # renamed from `e_obs` for compatibility with `run_inference_nuts`
sigma_m: float,
sigma_e_int: float,
):
e_obs = data
assert e_sheared.shape == (2,) and e_obs.shape == (2,)

prior = jnp.log(ellip_prior_e1e2(e_sheared, sigma=sigma_e_int))
likelihood = jnp.sum(jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m))
return prior + likelihood


def pipeline_toy_ellips(
key: PRNGKeyArray,
*,
g1: float,
g2: float,
sigma_e: float,
sigma_e_int: float,
sigma_m: float,
n_gals: int,
n_samples_per_gal: int,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
k1, k2 = random.split(key)

true_g = jnp.array([g1, g2])

e_obs, e_sheared, _ = sample_noisy_ellipticities_unclipped(
k1, g=true_g, sigma_m=sigma_m, sigma_e=sigma_e, n=n_gals
)

_logtarget = partial(logtarget_toy_ellips, sigma_m=sigma_m, sigma_e_int=sigma_e_int)

keys2 = random.split(k2, n_gals)
_do_inference_jitted = jit(
partial(
run_inference_nuts,
logtarget=_logtarget,
n_samples=n_samples_per_gal,
initial_step_size=max(sigma_e, sigma_m),
max_num_doublings=max_num_doublings,
n_warmup_steps=n_warmup_steps,
)
)
_do_inference = vmap(_do_inference_jitted, in_axes=(0, 0, 0))

# compile
_ = _do_inference(keys2[:2], e_sheared[:2], e_obs[:2])

e_post = _do_inference(keys2, e_sheared, e_obs)

return e_post, e_obs, e_sheared
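
Putting the two toy pipelines together, an end-to-end run might look like this sketch (all parameter values are illustrative, not from this commit):

import jax.numpy as jnp
from jax import random

from bpd.pipelines import pipeline_shear_inference_simple, pipeline_toy_ellips

key = random.PRNGKey(0)
k1, k2 = random.split(key)

# interim posterior samples of ellipticity for each noisy galaxy
e_post, e_obs, e_sheared = pipeline_toy_ellips(
    k1,
    g1=0.02,
    g2=0.0,
    sigma_e=0.2,
    sigma_e_int=0.3,
    sigma_m=1e-4,
    n_gals=1000,
    n_samples_per_gal=100,
)

# hierarchical shear inference on the pooled interim samples
g_samples = pipeline_shear_inference_simple(
    k2,
    e_post,
    init_g=jnp.array([0.0, 0.0]),
    sigma_e=0.2,
    sigma_e_int=0.3,
    n_samples=3000,
    initial_step_size=1e-3,
)
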
