Skip to content

Commit

Permalink
Rebase nrs deepcopy (spacetelescope#8793)
Browse files Browse the repository at this point in the history
Co-authored-by: Timothy Brandt <[email protected]>
Co-authored-by: Timothy D Brandt <[email protected]>
Co-authored-by: Melanie Clarke <[email protected]>
  • Loading branch information
4 people authored Sep 20, 2024
1 parent a327681 commit 9fdf2c1
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 43 deletions.
34 changes: 32 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ assign_wcs
- Moved `update_s_region_imaging`, `update_s_region_keyword`, and `wcs_from_footprints`
into stcal. [#8624]

- Add helper functions to copy only the necessary parts of the WCS so that
these parts can be used within loops, avoiding copying the full WCS within
a loop [#8793]

associations
------------

- Restored slit name to level 3 product names for NIRSpec BOTS and background
fixed slit targets. [#8699]

- Update warning message about use of paths in associations. [#8752]

- Remove ``MultilineLogger`` and no longer set it as the default logger. [#8781]
Expand Down Expand Up @@ -70,12 +74,14 @@ cube_build

- Removed direct setting of the ``self.skip`` attribute from within the step
itself. [#8600]

- Fixed a bug when ``cube_build`` was called from the ``mrs_imatch`` step. [#8728]

- Ensure that NaNs and DO_NOT_USE flags match up in all input data before
building a cube. [#8557]

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

datamodels
----------

Expand All @@ -102,6 +108,8 @@ flat_field
- Ensure that NaNs and DO_NOT_USE flags match up in all science, error,
variance, and DQ extensions for all modes. [#8557]

- Replaced deep copies of NIRSpec WCS objects within most loops [#8793]

general
-------

Expand All @@ -127,6 +135,8 @@ master_background
- Either of ``"background"`` or ``"bkg"`` in slit name now defines the slit
as a background slit, instead of ``"bkg"`` only. [#8600]

- Replaced deep copies of NIRSpec WCS objects within most loops [#8793]

model_blender
-------------

Expand All @@ -137,6 +147,11 @@ mrs_imatch

- Added a deprecation warning and set the default to skip=True for the step. [#8728]

msaflagopen
-----------

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

nsclean
-------

Expand All @@ -149,6 +164,8 @@ nsclean
can still be called from the ``calwebb_spec2`` pipeline on NIRSpec rate
data, but it is now deprecated. [#8669]

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

outlier_detection
-----------------

Expand All @@ -174,12 +191,16 @@ pathloss
- Ensure that NaNs and DO_NOT_USE flags match up in all output science, error,
variance, and DQ extensions. [#8557]

- Replaced deep copies of NIRSpec WCS objects within most loops [#8793]

photom
------

- Ensure that NaNs and DO_NOT_USE flags match up in all output science, error,
variance, and DQ extensions. [#8557]

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

pipeline
--------

Expand All @@ -192,6 +213,13 @@ pipeline

- Updated `calwebb_spec3` to not save the `pixel_replacement` output by default.[#8765]

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

pixel_replace
-------------

- Replaced deep copies of NIRSpec WCS objects within most loops. [#8793]

ramp_fitting
------------

Expand Down Expand Up @@ -522,6 +550,7 @@ master_background
wavelength range instead of NaN to avoid NaN-ing out entire
sets of science data when backgrounds are missing. [#8597]


master_background_mos
---------------------

Expand Down Expand Up @@ -598,6 +627,7 @@ photom
- Added a hook to bypass the ``photom`` step when the ``extract_1d`` step
was bypassed for non-TSO NIRISS SOSS exposures. [#8575]


pipeline
--------

Expand Down
128 changes: 127 additions & 1 deletion jwst/assign_wcs/nirspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import logging
import numpy as np
import copy

from astropy.modeling import models
from astropy.modeling.models import Mapping, Identity, Const1D, Scale, Tabular1D
Expand Down Expand Up @@ -1696,6 +1697,131 @@ def gwa_to_ymsa(msa2gwa_model, lam_cen=None, slit=None, slit_y_range=None):
return tab


def _get_transforms(input_model, slitnames, return_slits=False):

"""
Return a WCS object with necessary transforms for all slits.
This function enables the JWST pipeline to avoid excessive deep
copying of WCS objects in later steps. It is used internally in
the pipeline only and should not be used if any of the WCSs is
modified.
Parameters
----------
input_model : `~jwst.datamodels.JwstDataModel`
A data model with a WCS object for the all open slitlets in
an observation.
slitnames : list of int or str
Slit.name of all open slits.
return_slits : bool, optional
Return the open slits
Returns
-------
wcsobj : `~gwcs.wcs.WCS`
WCS object deep copied from input_model.meta.wcs
sca2gwa : `~astropy.modeling.core.Model`
Transform from ``sca`` to ``gwa``
gwa2slit : list of `~astropy.modeling.core.Model`
Transform from ``gwa`` to ``slit`` for each input slit
slit2slicer : list of `~astropy.modeling.core.Model`
Transform from ``slit_frame`` to ``slicer`` for each input slit
open_slits : list of `~stdatamodels.jwst.transforms.models.Slit`
open slits from wcs.get_transform('gwa', 'slit_frame').slits
Only returned if return_slits is True
"""

wcs = copy.deepcopy(input_model.meta.wcs)

sca2gwa = copy.deepcopy(wcs.pipeline[1].transform[1:])
wcs.set_transform('sca', 'gwa', sca2gwa)

gwa2slit = [copy.deepcopy(wcs.pipeline[2].transform.get_model(slit_name))
for slit_name in slitnames]

slit2slicer = [copy.deepcopy(wcs.pipeline[3].transform.get_model(slit_name))
for slit_name in slitnames]

if return_slits:
g2s = wcs.get_transform('gwa', 'slit_frame')
open_slits = g2s.slits
return wcs, sca2gwa, gwa2slit, slit2slicer, copy.deepcopy(open_slits)
else:
return wcs, sca2gwa, gwa2slit, slit2slicer


def _nrs_wcs_set_input_lite(input_model, input_wcs, slit_name, transforms,
wavelength_range=None, open_slits=None,
slit_y_low=None, slit_y_high=None):

"""
Return a WCS object for a specific slit, slice or shutter
The lite version of the routine is distinguished from the legacy
routine because it does not make a deep copy of the input WCS object.
Parameters
----------
input_model : `~jwst.datamodels.JwstDataModel`
A WCS object for the all open slitlets in an observation.
input_wcs : `~gwcs.wcs.WCS`
A WCS object for the all open slitlets in an observation. This
will be modified and returned.
slit_name : int or str
Slit.name of an open slit.
transforms : list of `~astropy.modeling.core.Model`
Model transforms output from ``_get_transforms``
wavelength_range: list
Wavelength range for the combination of filter and grating. Optional.
open_slits : list of slits
List of open slits. Optional.
Returns
-------
wcsobj : `~gwcs.wcs.WCS`
WCS object for this slit.
"""


def _get_y_range(input_model, open_slits):
if open_slits is None:
log_message = 'nrs_wcs_set_input_lite must be called with open_slits if not in ifu mode'
log.critical(log_message)
raise RuntimeError(log_message)
# Need the open slits to get the slit ymin,ymax
slit = [s for s in open_slits if s.name == slit_name][0]
return slit.ymin, slit.ymax

if wavelength_range is None:
_, wavelength_range = spectral_order_wrange_from_model(input_model)

slit_wcs = copy.copy(input_wcs)

slit_wcs.set_transform('sca', 'gwa', transforms[0])
slit_wcs.set_transform('gwa', 'slit_frame', transforms[1])

is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == 'nrs_ifu'

if is_nirspec_ifu:
slit_wcs.set_transform('slit_frame', 'slicer', transforms[2] & Identity(1))
else:
slit_wcs.set_transform('slit_frame', 'msa_frame', transforms[2] & Identity(1))

transform = slit_wcs.get_transform('detector', 'slit_frame')

if is_nirspec_ifu:
bb = compute_bounding_box(transform, wavelength_range)
else:
if slit_y_low is None or slit_y_high is None:
slit_y_low, slit_y_high = _get_y_range(input_model, open_slits)
bb = compute_bounding_box(transform, wavelength_range,
slit_ymin=slit_y_low, slit_ymax=slit_y_high)

slit_wcs.bounding_box = bb
return slit_wcs


def _nrs_wcs_set_input(input_model, slit_name):
"""
Returns a WCS object for a specific slit, slice or shutter.
Expand All @@ -1713,7 +1839,7 @@ def _nrs_wcs_set_input(input_model, slit_name):
wcsobj : `~gwcs.wcs.WCS`
WCS object for this slit.
"""
import copy

wcsobj = input_model.meta.wcs

slit_wcs = copy.deepcopy(wcsobj)
Expand Down
46 changes: 28 additions & 18 deletions jwst/clean_flicker_noise/clean_flicker_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

import gwcs
from gwcs.utils import _toindex
import numpy as np
from astropy.stats import sigma_clipped_stats, SigmaClip
from astropy.utils.exceptions import AstropyUserWarning
Expand Down Expand Up @@ -166,11 +167,14 @@ def mask_ifu_slices(input_model, mask):
# Initialize global DQ map to all zero (OK to use)
dqmap = np.zeros_like(input_model.dq)

# Get the wcs objects for all IFU slices
list_of_wcs = nirspec.nrs_ifu_wcs(input_model)
# Note: 30 in the line below is hardcoded in nirspec.nrs.ifu_wcs, which
# the line below replaces.
wcsobj, tr1, tr2, tr3 = nirspec._get_transforms(input_model, np.arange(30))

# Loop over the IFU slices, finding the valid region for each
for (k, ifu_wcs) in enumerate(list_of_wcs):
for k in range(len(tr2)):
ifu_wcs = nirspec._nrs_wcs_set_input_lite(input_model, wcsobj, k,
[tr1, tr2[k], tr3[k]])

# Construct array indexes for pixels in this slice
x, y = gwcs.wcstools.grid_from_bounding_box(
Expand Down Expand Up @@ -220,8 +224,6 @@ def mask_slits(input_model, mask):
2D output mask with additional flags for slit pixels
"""

from jwst.extract_2d.nirspec import offset_wcs

log.info("Finding slit/slitlet pixels")

# Get the slit-to-msa frame transform from the WCS object
Expand All @@ -230,9 +232,17 @@ def mask_slits(input_model, mask):
# Loop over the slits, marking all the pixels within each bounding
# box as False (do not use) in the mask.
# Note that for 3D masks (TSO mode), all planes will be set to the same value.
for slit in slit2msa.slits:
slit_wcs = nirspec.nrs_wcs_set_input(input_model, slit.name)
xlo, xhi, ylo, yhi = offset_wcs(slit_wcs)

slits = [s.name for s in slit2msa.slits]
wcsobj, tr1, tr2, tr3, open_slits = nirspec._get_transforms(input_model, slits, return_slits=True)

for k in range(len(tr2)):
slit_wcs = nirspec._nrs_wcs_set_input_lite(input_model, wcsobj, slits[k],
[tr1, tr2[k], tr3[k]],
open_slits=open_slits)

xlo, xhi = _toindex(slit_wcs.bounding_box[0])
ylo, yhi = _toindex(slit_wcs.bounding_box[1])
mask[..., ylo:yhi, xlo:xhi] = False

return mask
Expand Down Expand Up @@ -691,14 +701,14 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# basically copied from lib.py. Use a robust estimator for
# standard deviation, then exclude discrepant pixels and their
# four nearest neighbors from the fit.

if exclude_outliers:
med = np.median(image[mask])
std = 1.4825 * np.median(np.abs((image - med)[mask]))
outlier = mask & (np.abs(image - med) > sigrej * std)

mask = mask & (~outlier)

# also get four nearest neighbors of flagged pixels
mask[1:] = mask[1:] & (~outlier[:-1])
mask[:-1] = mask[:-1] & (~outlier[1:])
Expand All @@ -713,7 +723,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

# i1 will be the first row with a nonzero element in the mask
# imax will be the last row with a nonzero element in the mask

nonzero_mask_element = np.sum(mask, axis=1) > 0

if np.sum(nonzero_mask_element) == 0:
Expand All @@ -722,7 +732,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

i1 = np.amin(np.arange(mask.shape[0])[nonzero_mask_element])
imax = np.amax(np.arange(mask.shape[0])[nonzero_mask_element])

i1_vals = []
di_list = []
models = []
Expand All @@ -736,7 +746,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
if (sum_mask[k] - sum_mask[i1] > npix_iter
and sum_mask[-1] - sum_mask[i1] > 1.5 * npix_iter):
break

di = k - i1

i1_vals += [i1]
Expand All @@ -747,7 +757,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# outliers section-by-section; we have to do that earlier
# over the full array to get reliable values for the mean
# and standard deviation.

if np.mean(mask[i1:i1 + di]) > minfrac:
cleaner = NSCleanSubarray(image[i1:i1 + di], mask[i1:i1 + di],
fc=fc, exclude_outliers=False)
Expand All @@ -767,9 +777,9 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

# Step forward by half an interval so that we have
# overlapping fitting regions.

i1 += max(int(np.round(di/2)), 1)

model = np.zeros(image.shape)
tot_wgt = np.zeros(image.shape)

Expand All @@ -779,7 +789,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# Use nonzero weights everywhere so that if only one
# correction is available it gets unit weight when we
# normalize.

for i in range(len(models)):
wgt = 1.001 - np.abs(np.linspace(-1, 1, di_list[i]))[:, np.newaxis]
model[i1_vals[i]:i1_vals[i] + di_list[i]] += wgt*models[i]
Expand Down
Loading

0 comments on commit 9fdf2c1

Please sign in to comment.