From ddbc6a02796aa9776a7605f56288d9bace7e6a6d Mon Sep 17 00:00:00 2001 From: Christian Hayes Date: Fri, 7 Jun 2024 17:04:03 -0400 Subject: [PATCH] JP-3330: Add NIRSpec wavelength corrections to slit WCS (#8376) Co-authored-by: Melanie Clarke Co-authored-by: Howard Bushouse Co-authored-by: Nadia Dencheva --- CHANGES.rst | 23 ++++ docs/jwst/pathloss/description.rst | 4 + docs/jwst/wavecorr/description.rst | 22 ++-- jwst/flatfield/flat_field.py | 59 +-------- jwst/lib/tests/test_wcs_utils.py | 157 ++++++++++++++++++++++++ jwst/lib/wcs_utils.py | 21 ++-- jwst/pathloss/pathloss.py | 45 ++++--- jwst/wavecorr/tests/test_wavecorr.py | 145 ++++++++++++++++++---- jwst/wavecorr/wavecorr.py | 173 ++++++++++++++++++++------- 9 files changed, 503 insertions(+), 146 deletions(-) create mode 100644 jwst/lib/tests/test_wcs_utils.py diff --git a/CHANGES.rst b/CHANGES.rst index 2169ae6685..08c5e3fe30 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -108,6 +108,9 @@ flat_field - Update NIRSpec flatfield code for all modes to ensure SCI=ERR=NaN wherever the DO_NOT_USE flag is set in the DQ array. [#8463] +- Updated the NIRSpec flatfield code to use the new format of the ``wavecorr`` + wavelength zero-point corrections for point sources. [#8376] + general ------- @@ -116,6 +119,13 @@ general - Increase minimum required scipy. [#8441] +lib +--- + +- Updated the ``wcs_utils.get_wavelength`` to use the new format + of the ``wavecorr`` wavelength zero-point corrections for point + sources in NIRSpec slit data. [#8376] + master_background_mos --------------------- @@ -149,6 +159,13 @@ outlier_detection to detect outliers in TSO data, with user-defined rolling window width via the ``rolling_window_width`` parameter. [#8473] +pathloss +-------- + +- Updated pathloss calculations for NIRSpec fixed slit mode to use the appropriate + wavelengths for point and uniform sources if the ``wavecorr`` wavelength + zero-point corrections for point sources have been applied. [#8376] + photom ------ @@ -242,6 +259,12 @@ tweakreg message and skip ``tweakreg`` step when this condition is not satisfied and source confusion is possible during catalog matching. [#8476] +wavecorr +-------- + +- Changed the NIRSpec wavelength correction algorithm to include it in slit WCS + models and resampling. Fixed the sign of the wavelength corrections. [#8376] + wfss_contam ----------- diff --git a/docs/jwst/pathloss/description.rst b/docs/jwst/pathloss/description.rst index 2e10ad4cbe..65529b736a 100644 --- a/docs/jwst/pathloss/description.rst +++ b/docs/jwst/pathloss/description.rst @@ -45,6 +45,10 @@ Once the 1-D correction arrays have been computed, both forms of the correction (point and uniform) are interpolated, as a function of wavelength, into the 2-D space of the slit or IFU data and attached to the output data model (extensions "PATHLOSS_PS" and "PATHLOSS_UN") as a record of what was computed. +For fixed slit data, if the ``wavecorr`` step has been run to provide wavelength +corrections to point sources, the corrected wavelengths will be used to +calculate the point source pathloss, whereas the uncorrected wavelengths (appropriate +for uniform sources) will be used to calculate the uniform source pathlosses. The form of the 2-D correction (point or uniform) that's appropriate for the data is divided into the SCI and ERR arrays and propagated into the variance arrays of the science data. diff --git a/docs/jwst/wavecorr/description.rst b/docs/jwst/wavecorr/description.rst index 999ee6a6db..fa7e311d5a 100644 --- a/docs/jwst/wavecorr/description.rst +++ b/docs/jwst/wavecorr/description.rst @@ -22,16 +22,22 @@ These are recorded in the "SRCXPOS" and "SRCYPOS" keywords in the SCI extension header of each slitlet in a FITS product. The ``wavecorr`` step loops over all slit instances in the input -science product and applies a wavelength correction to slits that -contain a point source. The point source determination is based on the -value of the "SRCTYPE" keyword populated for each slit by the +science product and updates the WCS models of slits that contain a point +source to include a wavelength correction. The point source determination is +based on the value of the "SRCTYPE" keyword populated for each slit by the :ref:`srctype ` step. The computation of the correction is based on the "SRCXPOS" value. A value of 0.0 indicates a perfectly centered source, and ranges from -0.5 to +0.5 for sources at the extreme edges of a slit. The computation uses calibration data from the ``WAVECORR`` -reference file. The correction is computed as a 2-D grid of -wavelength offsets, which is applied to the original 2-D grid of -wavelengths associated with each slit. +reference file, which contains pixel shifts as a function of source position +and wavelength, and can be converted to wavelength shifts with the dispersion. +For each slit, the ``wavecorr`` step uses the average wavelengths and +dispersions in a slit (averaged across the cross-dispersion direction) to +calculate corresponding corrected wavelengths. It then uses the average +wavelengths and their corrections to generate a transform that interpolates +between "center of slit" wavelengths and corrected wavelengths. This +transformation is added to the slit WCS after the ``slit_frame`` and +produces a new wavelength corrected slit frame, ``wavecorr_frame``. NIRSpec Fixed Slit (FS) ----------------------- @@ -47,8 +53,8 @@ the wavelength correction is only applied to the primary slit. The estimated position of the source within the primary slit (in the dispersion direction) is then used in the same manner as described above -for MOS slitlets to compute offsets to be added to the nominal wavelength -grid for the primary slit. +for MOS slitlets to update the slit WCS and compute corrected wavelengths +for the primary slit. Upon successful completion of the step, the status keyword "S_WAVCOR" is set to "COMPLETE". diff --git a/jwst/flatfield/flat_field.py b/jwst/flatfield/flat_field.py index 1c4c1db6ea..01da31668d 100644 --- a/jwst/flatfield/flat_field.py +++ b/jwst/flatfield/flat_field.py @@ -6,12 +6,11 @@ import math import numpy as np -from gwcs.wcstools import grid_from_bounding_box from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels import dqflags -from ..lib import reffile_utils +from ..lib import reffile_utils, wcs_utils from ..assign_wcs import nirspec log = logging.getLogger(__name__) @@ -1924,59 +1923,11 @@ def flat_for_nirspec_slit(slit, f_flat_model, s_flat_model, d_flat_model, # Get the wavelength at each pixel in the extracted slit data. # If the wavelength attribute exists and is populated, use it # in preference to the wavelengths returned by the wcs function. - got_wl_attribute = True - try: - wl = slit.wavelength.copy() # a 2-D array - except AttributeError: - got_wl_attribute = False - if not got_wl_attribute or len(wl) == 0: - got_wl_attribute = False - return_dummy = False - # Has the use_wavecorr param been set? - if use_wavecorr is not None: - if use_wavecorr: - # Need to use the 2D wavelength array, because that's where - # the corrected wavelengths are stored - if got_wl_attribute: - # We've got the "wl" wavelength array we need - pass - else: - # Can't do the computation without the 2D wavelength array - log.error(f"The wavelength array for slit {slit.name} is not populated") - log.error("Skipping flat-field correction") - return_dummy = True - elif not use_wavecorr: - # Need to use the WCS object to create an uncorrected 2D wavelength array - if got_wcs: - log.info(f"Creating wavelength array from WCS for slit {slit.name}") - bb = slit.meta.wcs.bounding_box - grid = grid_from_bounding_box(bb) - wl = slit.meta.wcs(*grid)[2] - del grid - else: - # Can't create the uncorrected wavelengths without the WCS - log.error(f"Slit {slit.name} has no WCS object") - log.error("Skipping flat-field correction") - return_dummy = True - else: - # use_wavecorr was not specified, so use default processing - if not got_wl_attribute or np.nanmin(wl) == 0. and np.nanmax(wl) == 0.: - got_wl_attribute = False - log.warning(f"The wavelength array for slit {slit.name} has not been populated") - # Try to create it from the WCS - if got_wcs: - bb = slit.meta.wcs.bounding_box - grid = grid_from_bounding_box(bb) - wl = slit.meta.wcs(*grid)[2] - del grid - else: - log.warning("and this slit does not have a 'wcs' attribute") - log.warning("likely because assign_wcs has not been run.") - log.error("skipping ...") - return_dummy = True - else: - log.debug("Wavelengths are from the wavelength array.") + return_dummy = False + wl = wcs_utils.get_wavelengths(slit, use_wavecorr=use_wavecorr) + if wl is None: + return_dummy = True # Create and return a dummy flat as a placeholder, if necessary if return_dummy: diff --git a/jwst/lib/tests/test_wcs_utils.py b/jwst/lib/tests/test_wcs_utils.py new file mode 100644 index 0000000000..0cd731e771 --- /dev/null +++ b/jwst/lib/tests/test_wcs_utils.py @@ -0,0 +1,157 @@ +import numpy as np +from numpy.testing import assert_allclose + +from astropy import units as u +from astropy import coordinates as coord +from astropy.modeling.models import Mapping, Identity, Shift, Scale +from gwcs import wcstools, wcs +from gwcs import coordinate_frames as cf + +from stdatamodels.jwst import datamodels +from stdatamodels.jwst.transforms.models import NirissSOSSModel +from jwst.lib.wcs_utils import get_wavelengths +from jwst.assign_wcs import util + + +def create_model(): + + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + + sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS()) + slit_spatial = cf.Frame2D( + name="slit_spatial", + axes_order=(0, 1), + unit=("", ""), + axes_names=("x_slit", "y_slit"), + ) + + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + slit_frame = cf.CompositeFrame([slit_spatial, spec], name="slit_frame") + world = cf.CompositeFrame([sky, spec], name="world") + + det2slit = Mapping((0, 1, 1)) | (Identity(2) & (Scale(0.5) | Shift(0.5))) + slit2sky = Identity(3) + + slit_wcs = wcs.WCS([(det, det2slit), (slit_frame, slit2sky), (world, None)]) + + # compute wavelengths + + data = np.full((10, 10), fill_value=5.0) + + bounding_box = util.wcs_bbox_from_shape(data.shape) + + x, y = wcstools.grid_from_bounding_box(bounding_box, step=(1, 1)) + _, _, lam = slit_wcs(x, y) + lam = lam.astype(np.float32) + model = datamodels.SlitModel(data=data, wavelength=lam) + model.meta.wcs = slit_wcs + + return model + + +def create_mock_wl(): + wl = np.arange(10.0) + wl = wl[:, np.newaxis] + wl = np.repeat(wl, 10, axis=1) + wl = (wl * 0.5) + 0.5 + return wl + + +def test_get_wavelengths(): + + # create a mock SlitModel + model = create_model() + + # calculate what the wavelength array should be + wl_og = create_mock_wl() + + # Test that the get wavelengths returns the wavelength grid + wl = get_wavelengths(model) + assert_allclose(wl, wl_og) + + del model.wavelength + + # Check that wavelengths can be generated from wcs when the + # wavelength attribute is unavailable + wl = get_wavelengths(model) + assert_allclose(wl, wl_og) + + # Check that wavelengths are generated correctly when given a WFSS exp_type + wl = get_wavelengths(model, exp_type="NRC_TSGRISM") + assert_allclose(wl, wl_og) + + +def test_get_wavelengths_soss(): + + # create a mock SlitModel + model = create_model() + + del model.wavelength + model.meta.exposure.type = "NIS_SOSS" + + wcs = model.meta.wcs + new_wcs = NirissSOSSModel( + [ + 1, + ], + [ + wcs, + ], + ) + model.meta.wcs = new_wcs + + # calculate what the wavelength array should be + wl_og = create_mock_wl() + + wl = get_wavelengths(model, order=1) + assert_allclose(wl, wl_og) + + +def test_get_wavelength_wavecorr(): + + # create a mock SlitModel + model = create_model() + + wl_og = create_mock_wl() + + # Test use_wavecorr with no wavelength correction modificiation + # get_wavelengths should return the same wavelengths for use_wavecorr + # True and False + + wl_corr = get_wavelengths(model, use_wavecorr=True) + assert_allclose(wl_corr, wl_og) + + wl_uncorr = get_wavelengths(model, use_wavecorr=False) + assert_allclose(wl_corr, wl_uncorr) + + # Update the model wcs to add a wavelength corrected slit frame + slit_spatial = cf.Frame2D( + name="slit_spatial", + axes_order=(0, 1), + unit=("", ""), + axes_names=("x_slit", "y_slit"), + ) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + wcorr_frame = cf.CompositeFrame([slit_spatial, spec], name="wavecorr_frame") + + # Insert the new transform into the slit wcs object + wave2wavecorr = Identity(2) & Shift(0.1) + model.meta.wcs.insert_frame("slit_frame", wave2wavecorr, wcorr_frame) + + bounding_box = util.wcs_bbox_from_shape(model.data.shape) + x, y = wcstools.grid_from_bounding_box(bounding_box, step=(1, 1)) + _, _, lam = model.meta.wcs(x, y) + model.wavelength = lam + + # calculate what the corrected wavelength array should be + wl_corr_og = wl_og + 0.1 + + wl_corr = get_wavelengths(model, use_wavecorr=True) + assert_allclose(wl_corr, wl_corr_og) + + wl_uncorr = get_wavelengths(model, use_wavecorr=False) + assert_allclose(wl_uncorr, wl_og) diff --git a/jwst/lib/wcs_utils.py b/jwst/lib/wcs_utils.py index 5dbfceeffb..3070756f13 100644 --- a/jwst/lib/wcs_utils.py +++ b/jwst/lib/wcs_utils.py @@ -41,19 +41,24 @@ def get_wavelengths(model, exp_type="", order=None, use_wavecorr=None): got_wavelength = False wl_array = None - # If we've been asked to use the corrected wavelengths stored in - # the wavelength array, return those wavelengths. Otherwise, the - # results computed from the WCS object (below) will be returned. - if use_wavecorr is not None: - if use_wavecorr: - return wl_array - else: - got_wavelength = False # force wl computation below # Evaluate the WCS on the grid of pixel indexes, capturing only the # resulting wavelength values shape = model.data.shape grid = np.indices(shape[-2:], dtype=np.float64) + + # If we've been asked to use the uncorrected wavelengths we need to + # recalculate them from the wcs by skipping the transformation between + # the slit frame and the wavelength corrected slit frame. If the wavecorr_frame + # is not in the wcs assume that the wavelength correction has not been applied. + if use_wavecorr is not None: + if (not use_wavecorr and hasattr(model.meta, "wcs") + and 'wavecorr_frame' in model.meta.wcs.available_frames): + wcs = model.meta.wcs + detector2slit = wcs.get_transform('detector', 'slit_frame') + wavecorr2world = wcs.get_transform("wavecorr_frame", "world") + wl_array = (detector2slit | wavecorr2world)(grid[1], grid[0])[2] + return wl_array # If no existing wavelength array, compute one if hasattr(model.meta, "wcs") and not got_wavelength: diff --git a/jwst/pathloss/pathloss.py b/jwst/pathloss/pathloss.py index 4ad2f551b8..9586bc320e 100644 --- a/jwst/pathloss/pathloss.py +++ b/jwst/pathloss/pathloss.py @@ -976,24 +976,41 @@ def _corrections_for_fixedslit(slit, pathloss, exp_type, source_type): wavelength_pointsource *= 1.0e6 wavelength_uniformsource *= 1.0e6 - wavelength_array = slit.wavelength - - # Compute the point source pathloss 2D correction - pathloss_2d_ps = interpolate_onto_grid( - wavelength_array, - wavelength_pointsource, - pathloss_pointsource_vector) - - # Compute the uniform source pathloss 2D correction - pathloss_2d_un = interpolate_onto_grid( - wavelength_array, - wavelength_uniformsource, - pathloss_uniform_vector) - # Use the appropriate correction for this slit if is_pointsource(source_type or slit.source_type): + # calculate the point source corrected wavelengths and uncorrected wavelengths for the slit + wavelength_array_corr = get_wavelengths(slit, use_wavecorr=True) + wavelength_array_uncorr = get_wavelengths(slit, use_wavecorr=False) + + # Compute the point source pathloss 2D correction + pathloss_2d_ps = interpolate_onto_grid( + wavelength_array_corr, + wavelength_pointsource, + pathloss_pointsource_vector) + + # Compute the uniform source pathloss 2D correction + pathloss_2d_un = interpolate_onto_grid( + wavelength_array_uncorr, + wavelength_uniformsource, + pathloss_uniform_vector) + pathloss_2d = pathloss_2d_ps + else: + wavelength_array = slit.wavelength + + # Compute the point source pathloss 2D correction + pathloss_2d_ps = interpolate_onto_grid( + wavelength_array, + wavelength_pointsource, + pathloss_pointsource_vector) + + # Compute the uniform source pathloss 2D correction + pathloss_2d_un = interpolate_onto_grid( + wavelength_array, + wavelength_uniformsource, + pathloss_uniform_vector) + pathloss_2d = pathloss_2d_un # Save the corrections. The `data` portion is the correction used. diff --git a/jwst/wavecorr/tests/test_wavecorr.py b/jwst/wavecorr/tests/test_wavecorr.py index e4a06cb625..5ff61fc241 100644 --- a/jwst/wavecorr/tests/test_wavecorr.py +++ b/jwst/wavecorr/tests/test_wavecorr.py @@ -27,32 +27,60 @@ def test_wavecorr(): im = datamodels.ImageModel(hdul) im_wcs = AssignWcsStep.call(im) im_ex2d = Extract2dStep.call(im_wcs) - im_ex2d.slits[0].meta.wcs.bounding_box = ((-.5, 1432.5), (-.5, 37.5)) + bbox = ((-.5, 1432.5), (-.5, 37.5)) + im_ex2d.slits[0].meta.wcs.bounding_box = bbox + x, y = wcstools.grid_from_bounding_box(bbox) + ra, dec, lam_before = im_ex2d.slits[0].meta.wcs(x, y) + im_ex2d.slits[0].wavelength = lam_before im_src = SourceTypeStep.call(im_ex2d) + + # the mock msa source is an extended source, change to point for testing + im_src.slits[0].source_type = 'POINT' im_wave = WavecorrStep.call(im_src) # test dispersion is of the correct order # there's one slit only slit = im_src.slits[0] - x, y = wcstools.grid_from_bounding_box(slit.meta.wcs.bounding_box) dispersion = wavecorr.compute_dispersion(slit.meta.wcs, x, y) assert_allclose(dispersion[~np.isnan(dispersion)], 1e-9, atol=1e-10) - - # the difference in wavelength should be of the order of e-10 in um - assert_allclose(im_src.slits[0].wavelength - im_wave.slits[0].wavelength, 1e-10) - + + # test that the wavelength is on the order of microns + wavelength = wavecorr.compute_wavelength(slit.meta.wcs, x, y) + assert_allclose(np.nanmean(wavelength), 2.5, atol=0.1) + + # Check that the stored wavelengths were corrected + abs_wave_correction = np.abs(im_src.slits[0].wavelength - im_wave.slits[0].wavelength) + assert_allclose(np.nanmean(abs_wave_correction), 0.00046, atol=0.0001) + + # Check that the slit wcs has been updated to provide corrected wavelengths + corrected_wavelength = wavecorr.compute_wavelength(im_wave.slits[0].meta.wcs, x, y) + assert_allclose(im_wave.slits[0].wavelength, corrected_wavelength) + + # test the round-tripping on the wavelength correction transform + ref_name = im_wave.meta.ref_file.wavecorr.name + freference = datamodels.WaveCorrModel( + WavecorrStep.reference_uri_to_cache_path(ref_name, im.crds_observatory)) + + lam_uncorr = lam_before * 1e-6 + wave2wavecorr = wavecorr.calculate_wavelength_correction_transform( + lam_uncorr, dispersion, freference, slit.source_xpos, 'MOS') + lam_corr = wave2wavecorr(lam_uncorr) + assert_allclose(lam_uncorr, wave2wavecorr.inverse(lam_corr)) + # test on both sides of the shutter source_xpos1 = -.2 source_xpos2 = .2 - ra, dec, lam = slit.meta.wcs(x, y) - ref_name = im_wave.meta.ref_file.wavecorr.name - freference = datamodels.WaveCorrModel(WavecorrStep.reference_uri_to_cache_path(ref_name, im.crds_observatory)) - zero_point1 = wavecorr.compute_zero_point_correction(lam, freference, source_xpos1, 'MOS', dispersion) - zero_point2 = wavecorr.compute_zero_point_correction(lam, freference, source_xpos2, 'MOS', dispersion) - diff_correction = np.abs(zero_point1[1] - zero_point2[1]) - non_zero = diff_correction[diff_correction != 0] - assert_allclose(np.nanmean(non_zero), 0.75, atol=0.01) + wave_transform1 = wavecorr.calculate_wavelength_correction_transform( + lam_uncorr, dispersion, freference, source_xpos1, 'MOS') + wave_transform2 = wavecorr.calculate_wavelength_correction_transform( + lam_uncorr, dispersion, freference, source_xpos2, 'MOS') + + zero_point1 = wave_transform1(lam_uncorr) + zero_point2 = wave_transform2(lam_uncorr) + + diff_correction = np.abs(zero_point1 - zero_point2) + assert_allclose(np.nanmean(diff_correction), 8.0e-10, atol=0.1e-10) def test_ideal_to_v23_fs(): @@ -119,7 +147,64 @@ def test_skipped(): outw = WavecorrStep.call(outs) source_pos = (0.004938526981283373, -0.02795306204991911) - assert_allclose((outw.slits[ind].source_xpos, outw.slits[ind].source_ypos), source_pos) + assert_allclose((outw.slits[ind].source_xpos, outw.slits[ind].source_ypos),source_pos) + + # Test if the corrected wavelengths are not monotonically increasing + + # This case is not expected with real data, test that no correction + # transform is returned (a skip criterion) with simple case + # of flipped wavelength solutions, which produces a monotonically + # decreasing wavelengths + lam = np.tile(np.flip(np.arange(0.6, 5.5, 0.01)*1e-6), (22, 1)) + disp = np.tile(np.full(lam.shape[-1], -0.01)*1e-6, (22, 1)) + + ref_name = outw.meta.ref_file.wavecorr.name + reffile = datamodels.WaveCorrModel( + WavecorrStep.reference_uri_to_cache_path(ref_name, im.crds_observatory)) + source_xpos = 0.1 + aperture_name = 'S200A1' + + transform = wavecorr.calculate_wavelength_correction_transform( + lam, disp, reffile, source_xpos, aperture_name) + assert transform is None + + +def test_mos_slit_status(): + """ Test conditions that are skipped for mos slitlets.""" + + hdul = create_nirspec_mos_file() + msa_meta = os.path.join(jwst.__path__[0], *['assign_wcs', 'tests', 'data', 'msa_configuration.fits']) + hdul[0].header['MSAMETFL'] = msa_meta + hdul[0].header['MSAMETID'] = 12 + im = datamodels.ImageModel(hdul) + im_wcs = AssignWcsStep.call(im) + im_ex2d = Extract2dStep.call(im_wcs) + bbox = ((-.5, 1432.5), (-.5, 37.5)) + im_ex2d.slits[0].meta.wcs.bounding_box = bbox + x, y = wcstools.grid_from_bounding_box(bbox) + ra, dec, lam_before = im_ex2d.slits[0].meta.wcs(x, y) + im_ex2d.slits[0].wavelength = lam_before + im_src = SourceTypeStep.call(im_ex2d) + + # test the mock msa source as an extended source + im_src.slits[0].source_type = 'EXTENDED' + im_wave = WavecorrStep.call(im_src) + + # check that the step is recorded as completed + assert im_wave.meta.cal_step.wavecorr == 'COMPLETE' + + # check that the step is listed as skipped for extended mos sources + assert im_wave.slits[0].meta.cal_step.wavecorr == 'SKIPPED' + + # test the mock msa source as a point source + im_src.slits[0].source_type = 'POINT' + im_wave = WavecorrStep.call(im_src) + + # check that the step is recorded as completed + assert im_wave.meta.cal_step.wavecorr == 'COMPLETE' + + # check that the step is listed as complete for mos point sources + assert im_wave.slits[0].meta.cal_step.wavecorr == 'COMPLETE' def test_wavecorr_fs(): @@ -171,13 +256,31 @@ def test_wavecorr_fs(): dispersion = wavecorr.compute_dispersion(slit.meta.wcs, x, y) assert_allclose(dispersion[~np.isnan(dispersion)], 1e-8, atol=1.04e-8) + # Check that the slit wavelengths are consistent with the slit wcs + corrected_wavelength = wavecorr.compute_wavelength(slit.meta.wcs, x, y) + assert_allclose(slit.wavelength, corrected_wavelength) + + # test the roundtripping on the wavelength correction transform + ref_name = result.meta.ref_file.wavecorr.name + freference = datamodels.WaveCorrModel(WavecorrStep.reference_uri_to_cache_path(ref_name, im.crds_observatory)) + + lam_uncorr = lam_before * 1e-6 + wave2wavecorr = wavecorr.calculate_wavelength_correction_transform(lam_uncorr, dispersion, + freference, slit.source_xpos, 'S200A1') + lam_corr = wave2wavecorr(lam_uncorr) + assert_allclose(lam_uncorr, wave2wavecorr.inverse(lam_corr)) + # test on both sides of the slit center source_xpos1 = -.2 source_xpos2 = .2 - ref_name = result.meta.ref_file.wavecorr.name - freference = datamodels.WaveCorrModel(WavecorrStep.reference_uri_to_cache_path(ref_name, im.crds_observatory)) - zero_point1 = wavecorr.compute_zero_point_correction(lam_before, freference, source_xpos1, 'S200A1', dispersion) - zero_point2 = wavecorr.compute_zero_point_correction(lam_before, freference, source_xpos2, 'S200A1', dispersion) - diff_correction = np.abs(zero_point1[1] - zero_point2[1]) - assert_allclose(np.nanmean(diff_correction), 0.45, atol=0.01) + wave_transform1 = wavecorr.calculate_wavelength_correction_transform( + lam_uncorr, dispersion, freference, source_xpos1, 'S200A1') + wave_transform2 = wavecorr.calculate_wavelength_correction_transform( + lam_uncorr, dispersion, freference, source_xpos2, 'S200A1') + + zero_point1 = wave_transform1(lam_uncorr) + zero_point2 = wave_transform2(lam_uncorr) + + diff_correction = np.abs(zero_point1 - zero_point2) + assert_allclose(np.nanmean(diff_correction), 6.3e-9, atol=0.1e-9) diff --git a/jwst/wavecorr/wavecorr.py b/jwst/wavecorr/wavecorr.py index 0174eb1f59..46a92513b2 100644 --- a/jwst/wavecorr/wavecorr.py +++ b/jwst/wavecorr/wavecorr.py @@ -30,6 +30,10 @@ import logging import numpy as np from gwcs import wcstools +from gwcs import coordinate_frames as cf +from astropy import units as u +from astropy.modeling import tabular +from astropy.modeling.mappings import Identity from stdatamodels.jwst import datamodels from stdatamodels.jwst.transforms import models as trmodels @@ -89,16 +93,31 @@ def do_correction(input_model, wavecorr_file): input_model.meta.cal_step.wavecorr = 'SKIPPED' break if _is_point_source(slit, exp_type): - apply_zero_point_correction(slit, wavecorr_file) - output_model.meta.cal_step.wavecorr = 'COMPLETE' + completed = apply_zero_point_correction(slit, wavecorr_file) + if completed: + output_model.meta.cal_step.wavecorr = 'COMPLETE' + else: # pragma: no cover + log.warning(f'Corrections are not invertible for slit {slit.name}') + log.warning('Skipping wavecorr correction') + output_model.meta.cal_step.wavecorr = 'SKIPPED' + break # For MOS work on all slits containing a point source else: for slit in output_model.slits: if _is_point_source(slit, exp_type): - apply_zero_point_correction(slit, wavecorr_file) - output_model.meta.cal_step.wavecorr = 'COMPLETE' + completed = apply_zero_point_correction(slit, wavecorr_file) + if completed: + slit.meta.cal_step.wavecorr = 'COMPLETE' + else: # pragma: no cover + log.warning(f'Corrections are not invertible for slit {slit.name}') + log.warning('Skipping wavecorr correction') + slit.meta.cal_step.wavecorr = 'SKIPPED' + else: + slit.meta.cal_step.wavecorr = 'SKIPPED' + + output_model.meta.cal_step.wavecorr = 'COMPLETE' return output_model @@ -112,6 +131,11 @@ def apply_zero_point_correction(slit, reffile): Slit data to be corrected. reffile : str The ``wavecorr`` reference file. + + Returns + ------- + completed : bool + A flag to report whether the zero-point correction was added or skipped. """ log.info(f'slit name {slit.name}') slit_wcs = slit.meta.wcs @@ -127,67 +151,108 @@ def apply_zero_point_correction(slit, reffile): # For the MSA the aperture name is "MOS" aperture_name = "MOS" - lam = slit.wavelength.copy() + lam = slit.wavelength.copy() * 1e-6 dispersion = compute_dispersion(slit.meta.wcs) - corr, dq_lam = compute_zero_point_correction(lam, reffile, source_xpos, - aperture_name, dispersion) - # TODO: set a DQ flag to a TBD value for pixels where dq_lam == 0. - # The only purpose of dq_lam is to set that flag. - - # Wavelength is in um, the correction is computed in meters. - slit.wavelength = slit.wavelength - corr * 10 ** 6 - -def compute_zero_point_correction(lam, freference, source_xpos, - aperture_name, dispersion): - """ Compute the NIRSpec wavelength zero-point correction. + wave2wavecorr = calculate_wavelength_correction_transform(lam, + dispersion, + reffile, + source_xpos, + aperture_name) + # wave2wavecorr should not be None for real data + if wave2wavecorr is None: # pragma: no cover + completed = False + return completed + else: + # Make a new frame to insert into the slit wcs object + slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1), + unit=("", ""), axes_names=('x_slit', 'y_slit')) + spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), + axes_names=('wavelength',)) + wcorr_frame = cf.CompositeFrame( + [slit_spatial, spec], name='wavecorr_frame') + + # Insert the new transform into the slit wcs object + wave2wavecorr = Identity(2) & wave2wavecorr + slit_wcs.insert_frame('slit_frame', wave2wavecorr, wcorr_frame) + + # Update the stored wavelengths for the slit + slit.wavelength = compute_wavelength(slit_wcs) + + completed = True + return completed + + +def calculate_wavelength_correction_transform(lam, dispersion, freference, + source_xpos, aperture_name): + """ Generate a WCS transform for the NIRSpec wavelength zero-point correction + and add it to the WCS for each slit. Parameters ---------- lam : ndarray - Wavelength array. + Wavelength array [in m]. + dispersion : ndarray + The pixel dispersion [in m]. freference : str ``wavecorr`` reference file name. source_xpos : float X position of the source as a fraction of the slit size. aperture_name : str Aperture name. - dispersion : ndarray - The pixel dispersion [in m]. - + Returns ------- - lambda_corr : ndarray - Wavelength correction. - lam : ndarray - Interpolated wavelengths. Extrapolated values are reset to 0. - This is returned so that the DQ array can be updated with a flag - which indicates that no zero-point correction was done. + model : `~astropy.modeling.tabular.Tabular1D`or None + A model which takes wavelength inputs and returns zero-point + corrected wavelengths. Returns None if an invertible model + cannot be generated. """ + # Open the zero point reference model with datamodels.WaveCorrModel(freference) as wavecorr: for ap in wavecorr.apertures: if ap.aperture_name == aperture_name: log.info(f'Using wavelength zero-point correction for aperture {ap.aperture_name}') offset_model = ap.zero_point_offset.copy() - # TODO: implement variance - # variance = ap.variance.copy() - # width = ap.width break else: log.info(f'No wavelength zero-point correction found for slit {aperture_name}') - - deltax = source_xpos - lam = lam.copy() - lam_no_nans = lam[~np.isnan(lam)] + + # Set lookup table to extrapolate at bounds to recover wavelengths + # beyond model bounds, particularly for the red and blue ends of + # prism observations. fill_value = None sets the lookup tables + # to use the default extrapolation which is a linear extrapolation + # from scipy.interpolate.interpn offset_model.bounds_error = False - correction = offset_model(lam_no_nans * 10 ** -6, [deltax] * lam_no_nans.size) - lam[~np.isnan(lam)] = correction - - # The correction for pixels outside the slit and wavelengths - # outside the wave_range is 0. - lam[np.isnan(lam)] = 0. - lambda_cor = dispersion * lam - return lambda_cor, lam + offset_model.fill_value = None + + # Average the wavelength and dispersion across 2D extracted slit and remove nans + # So that we have a 1D wavelength array for building a 1D lookup table wcs transform + lam_mean = np.nanmean(lam, axis=0) + disp_mean = np.nanmean(dispersion, axis=0) + nan_lams = np.isnan(lam_mean) | np.isnan(disp_mean) + lam_mean = lam_mean[~nan_lams] + disp_mean = disp_mean[~nan_lams] + + # Calculate the corrected wavelengths + pixel_corrections = offset_model(lam_mean, source_xpos) + lam_corrected = lam_mean + (pixel_corrections * disp_mean) + + # Check to make sure that the corrected wavelengths are monotonically increasing + if np.all(np.diff(lam_corrected) > 0): + # monotonically increasing + # Build a look up table to transform between corrected and uncorrected wavelengths + wave2wavecorr = tabular.Tabular1D(points=lam_mean, + lookup_table=lam_corrected, + bounds_error=False, + fill_value=None, + name='wave2wavecorr') + + return wave2wavecorr + + else: + # output wavelengths are not monotonically increasing + return None def compute_dispersion(wcs, xpix=None, ypix=None): @@ -218,6 +283,32 @@ def compute_dispersion(wcs, xpix=None, ypix=None): return (lamright - lamleft) * 10 ** -6 +def compute_wavelength(wcs, xpix=None, ypix=None): + """ Compute the pixel wavelength. + + Parameters + ---------- + wcs : `~gwcs.wcs.WCS` + The WCS object for this slit. + xpix : ndarray, float, optional + ypix : ndarray, float, optional + Compute the wavelength at the x, y pixels. + If not provided the dispersion is computed on a + grid based on ``wcs.bounding_box``. + + Returns + ------- + wavelength : ndarray + The wavelength [in microns]. + + """ + if xpix is None or ypix is None: + xpix, ypix = wcstools.grid_from_bounding_box(wcs.bounding_box, step=(1, 1)) + + _, _, lam = wcs(xpix, ypix) + return lam + + def _is_point_source(slit, exp_type): """ Determine if a source is a point source.