Skip to content

Commit

Permalink
JP-3657: Part 1 of a NIRSpec BOTS speedup (#8609)
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke authored Jul 24, 2024
2 parents 50ec28d + ebd929c commit ce95ab6
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 29 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ extract_1d
all slits containing point sources are now handled consistently,
whether they are marked primary or not. [#8467]

- Added functionality to the phase-based aperture correction object to support
reuse of aperture correction objects across multiple integrations. [#8609]

- Changed extract.py to attempt to tabulate and reuse an aperture correction
object in integrations after the first one. This can save a very large
amount of time in BOTS reductions. [#8609]

extract_2d
----------

Expand Down
59 changes: 50 additions & 9 deletions jwst/extract_1d/apply_apcorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, input_model: DataModel, apcorr_table: fits.FITS_rec, sizeunit
self.reference = self._reduce_reftable()
self._convert_size_units()
self.apcorr_func = self.approximate()
self.tabulated_correction = None

def _convert_size_units(self):
"""If the SIZE or Radius column is in units of arcseconds, convert to pixels."""
Expand Down Expand Up @@ -227,30 +228,70 @@ def _approx_func(wavelength: float, size: float, pixel_phase: float) -> RectBiva
def measure_phase(self): # Future method in determining pixel phase
pass

def apply(self, spec_table: fits.FITS_rec):
"""Apply interpolated aperture correction value to source-related extraction results in-place.
def tabulate_correction(self, spec_table: fits.FITS_rec):
"""Tabulate the interpolated aperture correction value.
This will save time when applying it later, especially if it is to be applied to multiple integrations.
Modifies self.tabulated_correction.
Parameters
----------
spec_table : `~fits.FITS_rec`
Table of aperture corrections values from apcorr reference file.
"""
flux_cols_to_correct = ('flux', 'flux_error', 'surf_bright', 'sb_error')
var_cols_to_correct = ('flux_var_poisson', 'flux_var_rnoise', 'flux_var_flat',
'sb_var_poisson', 'sb_var_rnoise', 'sb_var_flat')

coefs = []
for row in spec_table:
try:
correction = self.apcorr_func(row['wavelength'], row['npixels'], self.phase)
except ValueError:
correction = None # Some input wavelengths might not be supported (especially at the ends of the range)

if correction:
for col in flux_cols_to_correct:
row[col] *= correction.item()
for col in var_cols_to_correct:
row[col] *= correction.item() * correction.item()
coefs += [correction.item()]
else:
coefs += [1]

self.tabulated_correction = np.asarray(coefs)

def apply(self, spec_table: fits.FITS_rec, use_tabulated=False):
"""Apply interpolated aperture correction value to source-related extraction results in-place.
Parameters
----------
spec_table : `~fits.FITS_rec`
Table of aperture corrections values from apcorr reference file.
use_tabulated : bool, Optional
Use self.tabulated_correction to perform the aperture correction?
Default False (recompute correction from scratch).
"""
flux_cols_to_correct = ('flux', 'flux_error', 'surf_bright', 'sb_error')
var_cols_to_correct = ('flux_var_poisson', 'flux_var_rnoise', 'flux_var_flat',
'sb_var_poisson', 'sb_var_rnoise', 'sb_var_flat')

if use_tabulated:
if self.tabulated_correction is None:
raise ValueError("Cannot call apply_tabulated_correction without first "
"calling tabulate_correction")

for col in flux_cols_to_correct:
spec_table[col] *= self.tabulated_correction
for col in var_cols_to_correct:
spec_table[col] *= self.tabulated_correction**2
else:
for row in spec_table:
try:
correction = self.apcorr_func(row['wavelength'], row['npixels'], self.phase)
except ValueError:
correction = None # Some input wavelengths might not be supported (especially at the ends of the range)

if correction:
for col in flux_cols_to_correct:
row[col] *= correction.item()
for col in var_cols_to_correct:
row[col] *= correction.item() * correction.item()


class ApCorrRadial(ApCorrBase):
Expand Down
72 changes: 52 additions & 20 deletions jwst/extract_1d/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3783,6 +3783,8 @@ def create_extraction(extract_ref_dict,
log.info(f"Beginning loop over {shape[0]} integrations ...")
integrations = range(shape[0])

ra_last = dec_last = wl_last = apcorr = None

for integ in integrations:
try:
ra, dec, wavelength, temp_flux, f_var_poisson, f_var_rnoise, \
Expand Down Expand Up @@ -3896,27 +3898,57 @@ def create_extraction(extract_ref_dict,
else:
wl = wavelength.min()

if isinstance(input_model, datamodels.ImageModel):
apcorr = select_apcorr(input_model)(
input_model,
apcorr_ref_model.apcorr_table,
apcorr_ref_model.sizeunit,
location=(ra, dec, wl)
)
else:
match_kwargs = {'location': (ra, dec, wl)}
if exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
match_kwargs['slit'] = slitname

apcorr = select_apcorr(input_model)(
input_model,
apcorr_ref_model.apcorr_table,
apcorr_ref_model.sizeunit,
slit_name=slitname,
**match_kwargs
)
apcorr.apply(spec.spec_table)
# Determine whether we have a tabulated aperture correction
# available to save time.

apcorr_available = False
if apcorr is not None:
if hasattr(apcorr, 'tabulated_correction'):
if apcorr.tabulated_correction is not None:
apcorr_available = True

# See whether we can reuse the previous aperture correction
# object. If so, just apply the pre-computed correction to
# save a ton of time.
if ra == ra_last and dec == dec_last and wl == wl_last and apcorr_available:
# re-use the last aperture correction
apcorr.apply(spec.spec_table, use_tabulated=True)

else:
if isinstance(input_model, datamodels.ImageModel):
apcorr = select_apcorr(input_model)(
input_model,
apcorr_ref_model.apcorr_table,
apcorr_ref_model.sizeunit,
location=(ra, dec, wl)
)
else:
match_kwargs = {'location': (ra, dec, wl)}
if exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']:
match_kwargs['slit'] = slitname

apcorr = select_apcorr(input_model)(
input_model,
apcorr_ref_model.apcorr_table,
apcorr_ref_model.sizeunit,
slit_name=slitname,
**match_kwargs
)
# Attempt to tabulate the aperture correction for later use.
# If this fails, fall back on the old method.
try:
apcorr.tabulate_correction(spec.spec_table)
apcorr.apply(spec.spec_table, use_tabulated=True)
log.info("Tabulating aperture correction for use in multiple integrations.")
except AttributeError:
log.info("Computing aperture correction.")
apcorr.apply(spec.spec_table)

# Save previous ra, dec, wavelength in case we can reuse
# the aperture correction object.
ra_last = ra
dec_last = dec
wl_last = wl
output_model.spec.append(spec)

if log_increment > 0 and (integ + 1) % log_increment == 0:
Expand Down

0 comments on commit ce95ab6

Please sign in to comment.