Skip to content

Commit

Permalink
refactor defaults and cli (#35)
Browse files Browse the repository at this point in the history
* avoid repeated nuts code throughout codebase

* refactor and use new function

* sigma seems off by sqrt(2)

* rename

* being careful with defaults

* being careful with defaults and renaming

* renaming

* reorder

* avoid certain defaults for now

* we don't need these anymore

* need to remove prior and move to target

* fix test after refactoring

* use typer

* typo

* typo

* fix tests

* rename and use typer

* rename

* more judicious

* step size, no default

* need more arguments now

* rtol change to be appropriate in test

* fix

* high snr is the default

* add typer, but finish in next PR

* rename

* rename

* fix corresponding slurm script
  • Loading branch information
ismael-mendoza authored Nov 4, 2024
1 parent 6915ae3 commit 94a0c89
Show file tree
Hide file tree
Showing 17 changed files with 483 additions and 1,387 deletions.
40 changes: 40 additions & 0 deletions bpd/chains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from functools import partial
from typing import Callable

import blackjax
import jax
from jax import random
from jax._src.prng import PRNGKeyArray
from jax.typing import ArrayLike


def inference_loop(rng_key, initial_state, kernel, n_samples: int):
Expand All @@ -12,3 +19,36 @@ def one_step(state, rng_key):
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

return (states, infos)


def run_inference_nuts(
    rng_key: PRNGKeyArray,
    init_positions: ArrayLike,
    data: ArrayLike,
    *,
    logtarget: Callable,
    n_samples: int,
    initial_step_size: float,
    max_num_doublings: int,
    n_warmup_steps: int = 500,
    target_acceptance_rate: float = 0.80,
    is_mass_matrix_diagonal: bool = True,
):
    """Run one NUTS chain with window adaptation and return the sampled positions.

    `logtarget` must accept the observed `data` as a keyword argument; it is
    bound here so the adapted kernel sees a density of the parameters alone.
    """
    warmup_key, sampling_key = random.split(rng_key)

    # condition the target density on the observed data
    conditioned_target = partial(logtarget, data=data)

    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        conditioned_target,
        progress_bar=False,
        is_mass_matrix_diagonal=is_mass_matrix_diagonal,
        max_num_doublings=max_num_doublings,
        initial_step_size=initial_step_size,
        target_acceptance_rate=target_acceptance_rate,
    )
    (warm_state, tuned_params), _ = adaptation.run(
        warmup_key, init_positions, n_warmup_steps
    )

    # sample with the tuned step size / mass matrix
    nuts_step = blackjax.nuts(conditioned_target, **tuned_params).step
    final_states, _ = inference_loop(
        sampling_key, warm_state, kernel=nuts_step, n_samples=n_samples
    )
    return final_states.position
12 changes: 7 additions & 5 deletions bpd/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ def draw_gaussian(
g2: float,
x: float,
y: float,
pixel_scale: float = 0.2,
slen: int = 53,
*,
slen: int,
fft_size: int, # rule of thumb: at least 4 times `slen`
psf_hlr: float = 0.7,
fft_size: int = 256, # rule of thumb, at least 4 times `slen`
pixel_scale: float = 0.2,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

Expand All @@ -39,9 +40,10 @@ def draw_gaussian_galsim(
g2: float,
x: float, # pixels
y: float,
pixel_scale: float = 0.2,
slen: int = 53,
*,
slen: int,
psf_hlr: float = 0.7,
pixel_scale: float = 0.2,
):
gal = galsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
Expand Down
79 changes: 14 additions & 65 deletions bpd/pipelines/image_ellips.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from functools import partial
from typing import Callable

import blackjax
import jax.numpy as jnp
from jax import Array, random
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.chains import run_inference_nuts
from bpd.draw import draw_gaussian, draw_gaussian_galsim
from bpd.noise import add_noise
from bpd.prior import ellip_mag_prior, sample_ellip_prior
Expand All @@ -17,7 +16,7 @@
def get_target_galaxy_params_simple(
rng_key: PRNGKeyArray,
shape_noise: float = 1e-3,
lf: float = 3.0,
lf: float = 6.0,
hlr: float = 1.0,
x: float = 0.0, # pixels
y: float = 0.0,
Expand All @@ -42,29 +41,24 @@ def get_target_images_single(
rng_key: PRNGKeyArray,
n_samples: int,
single_galaxy_params: dict[str, float],
psf_hlr: float = 0.7,
background: float = 1.0,
slen: int = 53,
pixel_scale: float = 0.2,
*,
background: float,
slen: int,
):
"""In this case, we sample multiple noise realizations of the same galaxy."""
assert "f" in single_galaxy_params and "lf" not in single_galaxy_params

noiseless = draw_gaussian_galsim(
**single_galaxy_params,
pixel_scale=pixel_scale,
psf_hlr=psf_hlr,
slen=slen,
)
noiseless = draw_gaussian_galsim(**single_galaxy_params, slen=slen)
return add_noise(rng_key, noiseless, bg=background, n=n_samples), noiseless


# interim prior
def logprior(
params: dict[str, Array],
*,
sigma_e: float,
flux_bds: tuple = (-1.0, 9.0),
hlr_bds: tuple = (0.01, 5.0),
sigma_e: float = 3e-2,
sigma_x: float = 1.0, # pixels
) -> Array:
prior = jnp.array(0.0)
Expand Down Expand Up @@ -107,71 +101,27 @@ def logtarget(
return logprior_fnc(params) + loglikelihood_fnc(params, data)


def do_inference(
    rng_key: PRNGKeyArray,
    init_positions: dict[str, Array],
    data: Array,
    *,
    logtarget_fnc: Callable,
    is_mass_matrix_diagonal: bool = False,
    n_warmup_steps: int = 500,
    max_num_doublings: int = 5,
    initial_step_size: float = 1e-3,
    target_acceptance_rate: float = 0.80,
    n_samples: int = 100,
):
    """Warm up and run a NUTS chain on a single-galaxy image posterior.

    Returns the positions of the sampled states after `n_warmup_steps` of
    window adaptation followed by `n_samples` draws.
    """
    warmup_key, sampling_key = random.split(rng_key)

    # bind the observed image so the density depends on the parameters only
    logdensity = partial(logtarget_fnc, data=data)

    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        logdensity,
        progress_bar=False,
        is_mass_matrix_diagonal=is_mass_matrix_diagonal,
        max_num_doublings=max_num_doublings,
        initial_step_size=initial_step_size,
        target_acceptance_rate=target_acceptance_rate,
    )
    (warm_state, tuned_params), _ = adaptation.run(
        warmup_key, init_positions, n_warmup_steps
    )

    nuts_step = blackjax.nuts(logdensity, **tuned_params).step
    states, _ = inference_loop(
        sampling_key, warm_state, kernel=nuts_step, n_samples=n_samples
    )
    return states.position


def pipeline_image_interim_samples(
def pipeline_image_interim_samples_one_galaxy(
rng_key: PRNGKeyArray,
true_params: dict[str, float],
target_image: Array,
*,
initialization_fnc: Callable,
sigma_e_int: float = 3e-2,
sigma_e_int: float,
n_samples: int = 100,
max_num_doublings: int = 5,
initial_step_size: float = 1e-3,
target_acceptance_rate: float = 0.80,
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = False,
slen: int = 53,
pixel_scale: float = 0.2,
psf_hlr: float = 0.7,
background: float = 1.0,
fft_size: int = 256,
background: float = 1.0,
):
k1, k2 = random.split(rng_key)

init_position = initialization_fnc(k1, true_params=true_params, data=target_image)

_draw_fnc = partial(
draw_gaussian,
pixel_scale=pixel_scale,
slen=slen,
psf_hlr=psf_hlr,
fft_size=fft_size,
)
_draw_fnc = partial(draw_gaussian, slen=slen, fft_size=fft_size)
_loglikelihood = partial(loglikelihood, draw_fnc=_draw_fnc, background=background)
_logprior = partial(logprior, sigma_e=sigma_e_int)

Expand All @@ -180,13 +130,12 @@ def pipeline_image_interim_samples(
)

_inference_fnc = partial(
do_inference,
logtarget_fnc=_logtarget,
run_inference_nuts,
logtarget=_logtarget,
is_mass_matrix_diagonal=is_mass_matrix_diagonal,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
target_acceptance_rate=target_acceptance_rate,
n_samples=n_samples,
)
_run_inference = jjit(_inference_fnc)
Expand Down
43 changes: 12 additions & 31 deletions bpd/pipelines/shear_inference.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,34 @@
from functools import partial
from typing import Callable

import blackjax
from jax import Array, random
from jax import Array
from jax import jit as jjit
from jax._src.prng import PRNGKeyArray
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.chains import run_inference_nuts
from bpd.likelihood import shear_loglikelihood
from bpd.prior import ellip_mag_prior


def logtarget_density(g: Array, e_post: Array, loglikelihood: Callable):
def logtarget_density(g: Array, *, data: Array, loglikelihood: Callable):
    """Unnormalized log-posterior density of the shear `g`.

    The prior is uniform over [-0.1, 0.1] per shear component; `data` carries
    the interim-posterior ellipticity samples (the keyword name matches what
    the NUTS runner passes through).
    """
    interim_samples = data
    prior_term = stats.uniform.logpdf(g, -0.1, 0.2).sum()
    like_term = loglikelihood(g, interim_samples)
    return prior_term + like_term


def do_inference(
    rng_key: PRNGKeyArray,
    init_g: Array,
    logtarget: Callable,
    n_samples: int,
    n_warmup_steps: int = 500,
):
    """Sample the shear posterior with NUTS after window adaptation.

    Adaptation settings (diagonal mass matrix, 2 doublings, 1e-2 initial step,
    0.80 target acceptance) are fixed for this low-dimensional target.
    """
    warmup_key, sampling_key = random.split(rng_key)

    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        logtarget,
        progress_bar=False,
        is_mass_matrix_diagonal=True,
        max_num_doublings=2,
        initial_step_size=1e-2,
        target_acceptance_rate=0.80,
    )
    (warm_state, tuned_params), _ = adaptation.run(warmup_key, init_g, n_warmup_steps)

    nuts_step = blackjax.nuts(logtarget, **tuned_params).step
    states, _ = inference_loop(
        sampling_key, warm_state, kernel=nuts_step, n_samples=n_samples
    )
    return states.position


def pipeline_shear_inference(
rng_key: PRNGKeyArray,
e_post: Array,
*,
true_g: Array,
sigma_e: float,
sigma_e_int: float,
n_samples: int,
initial_step_size: float,
n_warmup_steps: int = 500,
max_num_doublings: int = 2,
):
prior = partial(ellip_mag_prior, sigma=sigma_e)
interim_prior = partial(ellip_mag_prior, sigma=sigma_e_int)
Expand All @@ -59,13 +37,16 @@ def pipeline_shear_inference(
_loglikelihood = jjit(
partial(shear_loglikelihood, prior=prior, interim_prior=interim_prior)
)
_logtarget = partial(logtarget_density, loglikelihood=_loglikelihood, e_post=e_post)
_logtarget = partial(logtarget_density, loglikelihood=_loglikelihood)

_do_inference = partial(
do_inference,
run_inference_nuts,
data=e_post,
logtarget=_logtarget,
n_samples=n_samples,
n_warmup_steps=n_warmup_steps,
max_num_doublings=max_num_doublings,
initial_step_size=initial_step_size,
)

g_samples = _do_inference(rng_key, true_g)
Expand Down
Loading

0 comments on commit 94a0c89

Please sign in to comment.