Skip to content

Commit

Permalink
JP-3330: Add NIRSpec wavelength corrections to slit WCS (#8376)
Browse files Browse the repository at this point in the history
Co-authored-by: Melanie Clarke <[email protected]>
Co-authored-by: Howard Bushouse <[email protected]>
Co-authored-by: Nadia Dencheva <[email protected]>
  • Loading branch information
4 people authored Jun 7, 2024
1 parent 4abaa90 commit ddbc6a0
Show file tree
Hide file tree
Showing 9 changed files with 503 additions and 146 deletions.
23 changes: 23 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ flat_field
- Update NIRSpec flatfield code for all modes to ensure SCI=ERR=NaN wherever the
DO_NOT_USE flag is set in the DQ array. [#8463]

- Updated the NIRSpec flatfield code to use the new format of the ``wavecorr``
wavelength zero-point corrections for point sources. [#8376]

general
-------

Expand All @@ -116,6 +119,13 @@ general

- Increase minimum required scipy. [#8441]

lib
---

- Updated the ``wcs_utils.get_wavelength`` to use the new format
of the ``wavecorr`` wavelength zero-point corrections for point
sources in NIRSpec slit data. [#8376]

master_background_mos
---------------------

Expand Down Expand Up @@ -149,6 +159,13 @@ outlier_detection
to detect outliers in TSO data, with user-defined
rolling window width via the ``rolling_window_width`` parameter. [#8473]

pathloss
--------

- Updated pathloss calculations for NIRSpec fixed slit mode to use the appropriate
wavelengths for point and uniform sources if the ``wavecorr`` wavelength
zero-point corrections for point sources have been applied. [#8376]

photom
------

Expand Down Expand Up @@ -242,6 +259,12 @@ tweakreg
message and skip ``tweakreg`` step when this condition is not satisfied and
source confusion is possible during catalog matching. [#8476]

wavecorr
--------

- Changed the NIRSpec wavelength correction algorithm to include it in slit WCS
models and resampling. Fixed the sign of the wavelength corrections. [#8376]

wfss_contam
-----------

Expand Down
4 changes: 4 additions & 0 deletions docs/jwst/pathloss/description.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ Once the 1-D correction arrays have been computed, both forms of the correction
(point and uniform) are interpolated, as a function of wavelength, into
the 2-D space of the slit or IFU data and attached to the output data model
(extensions "PATHLOSS_PS" and "PATHLOSS_UN") as a record of what was computed.
For fixed slit data, if the ``wavecorr`` step has been run to provide wavelength
corrections to point sources, the corrected wavelengths will be used to
calculate the point source pathloss, whereas the uncorrected wavelengths (appropriate
for uniform sources) will be used to calculate the uniform source pathlosses.
The form of the 2-D correction (point or uniform) that's appropriate for the
data is divided into the SCI and ERR arrays and propagated into the variance
arrays of the science data.
Expand Down
22 changes: 14 additions & 8 deletions docs/jwst/wavecorr/description.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@ These are recorded in the "SRCXPOS" and "SRCYPOS" keywords in the SCI
extension header of each slitlet in a FITS product.

The ``wavecorr`` step loops over all slit instances in the input
science product and applies a wavelength correction to slits that
contain a point source. The point source determination is based on the
value of the "SRCTYPE" keyword populated for each slit by the
science product and updates the WCS models of slits that contain a point
source to include a wavelength correction. The point source determination is
based on the value of the "SRCTYPE" keyword populated for each slit by the
:ref:`srctype <srctype_step>` step. The computation of the correction is
based on the "SRCXPOS" value. A value of 0.0 indicates a perfectly centered
source, and ranges from -0.5 to +0.5 for sources at the extreme edges
of a slit. The computation uses calibration data from the ``WAVECORR``
reference file. The correction is computed as a 2-D grid of
wavelength offsets, which is applied to the original 2-D grid of
wavelengths associated with each slit.
reference file, which contains pixel shifts as a function of source position
and wavelength, and can be converted to wavelength shifts with the dispersion.
For each slit, the ``wavecorr`` step uses the average wavelengths and
dispersions in a slit (averaged across the cross-dispersion direction) to
calculate corresponding corrected wavelengths. It then uses the average
wavelengths and their corrections to generate a transform that interpolates
between "center of slit" wavelengths and corrected wavelengths. This
transformation is added to the slit WCS after the ``slit_frame`` and
produces a new wavelength corrected slit frame, ``wavecorr_frame``.

NIRSpec Fixed Slit (FS)
-----------------------
Expand All @@ -47,8 +53,8 @@ the wavelength correction is only applied to the primary slit.

The estimated position of the source within the primary slit (in the
dispersion direction) is then used in the same manner as described above
for MOS slitlets to compute offsets to be added to the nominal wavelength
grid for the primary slit.
for MOS slitlets to update the slit WCS and compute corrected wavelengths
for the primary slit.

Upon successful completion of the step, the status keyword "S_WAVCOR"
is set to "COMPLETE".
59 changes: 5 additions & 54 deletions jwst/flatfield/flat_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import math

import numpy as np
from gwcs.wcstools import grid_from_bounding_box

from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags

from ..lib import reffile_utils
from ..lib import reffile_utils, wcs_utils
from ..assign_wcs import nirspec

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1924,59 +1923,11 @@ def flat_for_nirspec_slit(slit, f_flat_model, s_flat_model, d_flat_model,
# Get the wavelength at each pixel in the extracted slit data.
# If the wavelength attribute exists and is populated, use it
# in preference to the wavelengths returned by the wcs function.
got_wl_attribute = True
try:
wl = slit.wavelength.copy() # a 2-D array
except AttributeError:
got_wl_attribute = False
if not got_wl_attribute or len(wl) == 0:
got_wl_attribute = False
return_dummy = False

# Has the use_wavecorr param been set?
if use_wavecorr is not None:
if use_wavecorr:
# Need to use the 2D wavelength array, because that's where
# the corrected wavelengths are stored
if got_wl_attribute:
# We've got the "wl" wavelength array we need
pass
else:
# Can't do the computation without the 2D wavelength array
log.error(f"The wavelength array for slit {slit.name} is not populated")
log.error("Skipping flat-field correction")
return_dummy = True
elif not use_wavecorr:
# Need to use the WCS object to create an uncorrected 2D wavelength array
if got_wcs:
log.info(f"Creating wavelength array from WCS for slit {slit.name}")
bb = slit.meta.wcs.bounding_box
grid = grid_from_bounding_box(bb)
wl = slit.meta.wcs(*grid)[2]
del grid
else:
# Can't create the uncorrected wavelengths without the WCS
log.error(f"Slit {slit.name} has no WCS object")
log.error("Skipping flat-field correction")
return_dummy = True
else:
# use_wavecorr was not specified, so use default processing
if not got_wl_attribute or np.nanmin(wl) == 0. and np.nanmax(wl) == 0.:
got_wl_attribute = False
log.warning(f"The wavelength array for slit {slit.name} has not been populated")
# Try to create it from the WCS
if got_wcs:
bb = slit.meta.wcs.bounding_box
grid = grid_from_bounding_box(bb)
wl = slit.meta.wcs(*grid)[2]
del grid
else:
log.warning("and this slit does not have a 'wcs' attribute")
log.warning("likely because assign_wcs has not been run.")
log.error("skipping ...")
return_dummy = True
else:
log.debug("Wavelengths are from the wavelength array.")
return_dummy = False
wl = wcs_utils.get_wavelengths(slit, use_wavecorr=use_wavecorr)
if wl is None:
return_dummy = True

# Create and return a dummy flat as a placeholder, if necessary
if return_dummy:
Expand Down
157 changes: 157 additions & 0 deletions jwst/lib/tests/test_wcs_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import numpy as np
from numpy.testing import assert_allclose

from astropy import units as u
from astropy import coordinates as coord
from astropy.modeling.models import Mapping, Identity, Shift, Scale
from gwcs import wcstools, wcs
from gwcs import coordinate_frames as cf

from stdatamodels.jwst import datamodels
from stdatamodels.jwst.transforms.models import NirissSOSSModel
from jwst.lib.wcs_utils import get_wavelengths
from jwst.assign_wcs import util


def create_model():

det = cf.Frame2D(name="detector", axes_order=(0, 1))

sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS())
slit_spatial = cf.Frame2D(
name="slit_spatial",
axes_order=(0, 1),
unit=("", ""),
axes_names=("x_slit", "y_slit"),
)

spec = cf.SpectralFrame(
name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",)
)
slit_frame = cf.CompositeFrame([slit_spatial, spec], name="slit_frame")
world = cf.CompositeFrame([sky, spec], name="world")

det2slit = Mapping((0, 1, 1)) | (Identity(2) & (Scale(0.5) | Shift(0.5)))
slit2sky = Identity(3)

slit_wcs = wcs.WCS([(det, det2slit), (slit_frame, slit2sky), (world, None)])

# compute wavelengths

data = np.full((10, 10), fill_value=5.0)

bounding_box = util.wcs_bbox_from_shape(data.shape)

x, y = wcstools.grid_from_bounding_box(bounding_box, step=(1, 1))
_, _, lam = slit_wcs(x, y)
lam = lam.astype(np.float32)
model = datamodels.SlitModel(data=data, wavelength=lam)
model.meta.wcs = slit_wcs

return model


def create_mock_wl():
wl = np.arange(10.0)
wl = wl[:, np.newaxis]
wl = np.repeat(wl, 10, axis=1)
wl = (wl * 0.5) + 0.5
return wl


def test_get_wavelengths():

# create a mock SlitModel
model = create_model()

# calculate what the wavelength array should be
wl_og = create_mock_wl()

# Test that the get wavelengths returns the wavelength grid
wl = get_wavelengths(model)
assert_allclose(wl, wl_og)

del model.wavelength

# Check that wavelengths can be generated from wcs when the
# wavelength attribute is unavailable
wl = get_wavelengths(model)
assert_allclose(wl, wl_og)

# Check that wavelengths are generated correctly when given a WFSS exp_type
wl = get_wavelengths(model, exp_type="NRC_TSGRISM")
assert_allclose(wl, wl_og)


def test_get_wavelengths_soss():

# create a mock SlitModel
model = create_model()

del model.wavelength
model.meta.exposure.type = "NIS_SOSS"

wcs = model.meta.wcs
new_wcs = NirissSOSSModel(
[
1,
],
[
wcs,
],
)
model.meta.wcs = new_wcs

# calculate what the wavelength array should be
wl_og = create_mock_wl()

wl = get_wavelengths(model, order=1)
assert_allclose(wl, wl_og)


def test_get_wavelength_wavecorr():

# create a mock SlitModel
model = create_model()

wl_og = create_mock_wl()

# Test use_wavecorr with no wavelength correction modificiation
# get_wavelengths should return the same wavelengths for use_wavecorr
# True and False

wl_corr = get_wavelengths(model, use_wavecorr=True)
assert_allclose(wl_corr, wl_og)

wl_uncorr = get_wavelengths(model, use_wavecorr=False)
assert_allclose(wl_corr, wl_uncorr)

# Update the model wcs to add a wavelength corrected slit frame
slit_spatial = cf.Frame2D(
name="slit_spatial",
axes_order=(0, 1),
unit=("", ""),
axes_names=("x_slit", "y_slit"),
)
spec = cf.SpectralFrame(
name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",)
)
wcorr_frame = cf.CompositeFrame([slit_spatial, spec], name="wavecorr_frame")

# Insert the new transform into the slit wcs object
wave2wavecorr = Identity(2) & Shift(0.1)
model.meta.wcs.insert_frame("slit_frame", wave2wavecorr, wcorr_frame)

bounding_box = util.wcs_bbox_from_shape(model.data.shape)
x, y = wcstools.grid_from_bounding_box(bounding_box, step=(1, 1))
_, _, lam = model.meta.wcs(x, y)
model.wavelength = lam

# calculate what the corrected wavelength array should be
wl_corr_og = wl_og + 0.1

wl_corr = get_wavelengths(model, use_wavecorr=True)
assert_allclose(wl_corr, wl_corr_og)

wl_uncorr = get_wavelengths(model, use_wavecorr=False)
assert_allclose(wl_uncorr, wl_og)
21 changes: 13 additions & 8 deletions jwst/lib/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ def get_wavelengths(model, exp_type="", order=None, use_wavecorr=None):
got_wavelength = False
wl_array = None

# If we've been asked to use the corrected wavelengths stored in
# the wavelength array, return those wavelengths. Otherwise, the
# results computed from the WCS object (below) will be returned.
if use_wavecorr is not None:
if use_wavecorr:
return wl_array
else:
got_wavelength = False # force wl computation below

# Evaluate the WCS on the grid of pixel indexes, capturing only the
# resulting wavelength values
shape = model.data.shape
grid = np.indices(shape[-2:], dtype=np.float64)

# If we've been asked to use the uncorrected wavelengths we need to
# recalculate them from the wcs by skipping the transformation between
# the slit frame and the wavelength corrected slit frame. If the wavecorr_frame
# is not in the wcs assume that the wavelength correction has not been applied.
if use_wavecorr is not None:
if (not use_wavecorr and hasattr(model.meta, "wcs")
and 'wavecorr_frame' in model.meta.wcs.available_frames):
wcs = model.meta.wcs
detector2slit = wcs.get_transform('detector', 'slit_frame')
wavecorr2world = wcs.get_transform("wavecorr_frame", "world")
wl_array = (detector2slit | wavecorr2world)(grid[1], grid[0])[2]
return wl_array

# If no existing wavelength array, compute one
if hasattr(model.meta, "wcs") and not got_wavelength:
Expand Down
Loading

0 comments on commit ddbc6a0

Please sign in to comment.