diff --git a/CHANGES.rst b/CHANGES.rst index c581f1cb4c..612044094f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -201,6 +201,13 @@ extract_1d all slits containing point sources are now handled consistently, whether they are marked primary or not. [#8467] +- Added functionality to the phase-based aperture correction object to support + reuse of aperture correction objects across multiple integrations. [#8609] + +- Changed extract.py to attempt to tabulate and reuse an aperture correction + object in integrations after the first one. This can save a very large + amount of time in BOTS reductions. [#8609] + extract_2d ---------- diff --git a/jwst/extract_1d/apply_apcorr.py b/jwst/extract_1d/apply_apcorr.py index d39ba45235..daf9f67588 100644 --- a/jwst/extract_1d/apply_apcorr.py +++ b/jwst/extract_1d/apply_apcorr.py @@ -73,6 +73,7 @@ def __init__(self, input_model: DataModel, apcorr_table: fits.FITS_rec, sizeunit self.reference = self._reduce_reftable() self._convert_size_units() self.apcorr_func = self.approximate() + self.tabulated_correction = None def _convert_size_units(self): """If the SIZE or Radius column is in units of arcseconds, convert to pixels.""" @@ -227,8 +228,11 @@ def _approx_func(wavelength: float, size: float, pixel_phase: float) -> RectBiva def measure_phase(self): # Future method in determining pixel phase pass - def apply(self, spec_table: fits.FITS_rec): - """Apply interpolated aperture correction value to source-related extraction results in-place. + def tabulate_correction(self, spec_table: fits.FITS_rec): + """Tabulate the interpolated aperture correction value. + + This will save time when applying it later, especially if it is to be applied to multiple integrations. + Modifies self.tabulated_correction. Parameters ---------- @@ -236,10 +240,8 @@ def apply(self, spec_table: fits.FITS_rec): Table of aperture corrections values from apcorr reference file. """ - flux_cols_to_correct = ('flux', 'flux_error', 'surf_bright', 'sb_error') - var_cols_to_correct = ('flux_var_poisson', 'flux_var_rnoise', 'flux_var_flat', - 'sb_var_poisson', 'sb_var_rnoise', 'sb_var_flat') + coefs = [] for row in spec_table: try: correction = self.apcorr_func(row['wavelength'], row['npixels'], self.phase) @@ -247,10 +249,49 @@ def apply(self, spec_table: fits.FITS_rec): correction = None # Some input wavelengths might not be supported (especially at the ends of the range) if correction: - for col in flux_cols_to_correct: - row[col] *= correction.item() - for col in var_cols_to_correct: - row[col] *= correction.item() * correction.item() + coefs += [correction.item()] + else: + coefs += [1] + + self.tabulated_correction = np.asarray(coefs) + + def apply(self, spec_table: fits.FITS_rec, use_tabulated=False): + """Apply interpolated aperture correction value to source-related extraction results in-place. + + Parameters + ---------- + spec_table : `~fits.FITS_rec` + Table of aperture corrections values from apcorr reference file. + use_tabulated : bool, Optional + Use self.tabulated_correction to perform the aperture correction? + Default False (recompute correction from scratch). + + """ + flux_cols_to_correct = ('flux', 'flux_error', 'surf_bright', 'sb_error') + var_cols_to_correct = ('flux_var_poisson', 'flux_var_rnoise', 'flux_var_flat', + 'sb_var_poisson', 'sb_var_rnoise', 'sb_var_flat') + + if use_tabulated: + if self.tabulated_correction is None: + raise ValueError("Cannot call apply_tabulated_correction without first " + "calling tabulate_correction") + + for col in flux_cols_to_correct: + spec_table[col] *= self.tabulated_correction + for col in var_cols_to_correct: + spec_table[col] *= self.tabulated_correction**2 + else: + for row in spec_table: + try: + correction = self.apcorr_func(row['wavelength'], row['npixels'], self.phase) + except ValueError: + correction = None # Some input wavelengths might not be supported (especially at the ends of the range) + + if correction: + for col in flux_cols_to_correct: + row[col] *= correction.item() + for col in var_cols_to_correct: + row[col] *= correction.item() * correction.item() class ApCorrRadial(ApCorrBase): diff --git a/jwst/extract_1d/extract.py b/jwst/extract_1d/extract.py index 32442c8f83..03b36b8968 100644 --- a/jwst/extract_1d/extract.py +++ b/jwst/extract_1d/extract.py @@ -3783,6 +3783,8 @@ def create_extraction(extract_ref_dict, log.info(f"Beginning loop over {shape[0]} integrations ...") integrations = range(shape[0]) + ra_last = dec_last = wl_last = apcorr = None + for integ in integrations: try: ra, dec, wavelength, temp_flux, f_var_poisson, f_var_rnoise, \ @@ -3896,27 +3898,57 @@ def create_extraction(extract_ref_dict, else: wl = wavelength.min() - if isinstance(input_model, datamodels.ImageModel): - apcorr = select_apcorr(input_model)( - input_model, - apcorr_ref_model.apcorr_table, - apcorr_ref_model.sizeunit, - location=(ra, dec, wl) - ) - else: - match_kwargs = {'location': (ra, dec, wl)} - if exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']: - match_kwargs['slit'] = slitname - - apcorr = select_apcorr(input_model)( - input_model, - apcorr_ref_model.apcorr_table, - apcorr_ref_model.sizeunit, - slit_name=slitname, - **match_kwargs - ) - apcorr.apply(spec.spec_table) + # Determine whether we have a tabulated aperture correction + # available to save time. + + apcorr_available = False + if apcorr is not None: + if hasattr(apcorr, 'tabulated_correction'): + if apcorr.tabulated_correction is not None: + apcorr_available = True + + # See whether we can reuse the previous aperture correction + # object. If so, just apply the pre-computed correction to + # save a ton of time. + if ra == ra_last and dec == dec_last and wl == wl_last and apcorr_available: + # re-use the last aperture correction + apcorr.apply(spec.spec_table, use_tabulated=True) + else: + if isinstance(input_model, datamodels.ImageModel): + apcorr = select_apcorr(input_model)( + input_model, + apcorr_ref_model.apcorr_table, + apcorr_ref_model.sizeunit, + location=(ra, dec, wl) + ) + else: + match_kwargs = {'location': (ra, dec, wl)} + if exp_type in ['NRS_FIXEDSLIT', 'NRS_BRIGHTOBJ']: + match_kwargs['slit'] = slitname + + apcorr = select_apcorr(input_model)( + input_model, + apcorr_ref_model.apcorr_table, + apcorr_ref_model.sizeunit, + slit_name=slitname, + **match_kwargs + ) + # Attempt to tabulate the aperture correction for later use. + # If this fails, fall back on the old method. + try: + apcorr.tabulate_correction(spec.spec_table) + apcorr.apply(spec.spec_table, use_tabulated=True) + log.info("Tabulating aperture correction for use in multiple integrations.") + except AttributeError: + log.info("Computing aperture correction.") + apcorr.apply(spec.spec_table) + + # Save previous ra, dec, wavelength in case we can reuse + # the aperture correction object. + ra_last = ra + dec_last = dec + wl_last = wl output_model.spec.append(spec) if log_increment > 0 and (integ + 1) % log_increment == 0: