Skip to content

Commit

Permalink
Merge pull request #879 from mperrin/fix_unit_bug_in_fast_datacube
Browse files Browse the repository at this point in the history
more strict units handling; fixes some issues for astropy 6.0.0 compatibility
  • Loading branch information
obi-wan76 authored Jul 8, 2024
2 parents 23c180e + 15f5374 commit 89fcb26
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions webbpsf/webbpsf_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,6 +1813,7 @@ def load_wss_opd_by_date(self, date=None, choice='closest', verbose=True, plot=F
opd_fn = webbpsf.mast_wss.get_opd_at_time(date, verbose=verbose, choice=choice, **kwargs)
self.load_wss_opd(opd_fn, verbose=verbose, plot=plot, **kwargs)

@poppy.utils.quantity_input(wavelengths=units.meter)
def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
add_distortion=True, *args, **kwargs):
"""Calculate a spectral datacube of PSFs: Simplified, much MUCH faster version.
Expand Down Expand Up @@ -1869,7 +1870,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
# Set up cube and initialize structure based on PSF at a representative wavelength
_log.info('Starting fast/simplified multiwavelength data cube calculation.')
ref_wave = np.mean(wavelengths)
MIN_REF_WAVE = 2e-6 # This must not be too short, to avoid phase wrapping for the C3 bump
MIN_REF_WAVE = 2e-6 * units.meter # This must not be too short, to avoid phase wrapping for the C3 bump
if ref_wave < MIN_REF_WAVE:
ref_wave = MIN_REF_WAVE
log_message = (
Expand Down Expand Up @@ -1897,7 +1898,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
ext = 0
cubefast[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1]))
cubefast[ext].data[0] = psf[ext].data
cubefast[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0]
cubefast[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0].to_value(units.meter)

# Fast way. Assumes wavelength-independent phase and amplitude at the exit pupil!!
if compare_methods:
Expand Down Expand Up @@ -1928,7 +1929,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
wl = wavelengths[i]
psfw = quickosys.calc_psf(wavelength=wl, normalize='None')
cubefast[0].data[i] = psfw[0].data
cubefast[ext].header[label_wavelength(nwavelengths, i)] = wavelengths[i]
cubefast[ext].header[label_wavelength(nwavelengths, i)] = wavelengths[i].to_value(units.meter)

cubefast[0].header['NWAVES'] = nwavelengths

Expand All @@ -1945,7 +1946,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
for ext in range(len(psf)):
cube[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1]))
cube[ext].data[0] = psf[ext].data
cube[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0]
cube[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0].to_value(units.meter)

# iterate rest of wavelengths
print('Running standard way')
Expand All @@ -1954,7 +1955,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None,
psf = self.calc_psf(*args, monochromatic=wl, **kwargs)
for ext in range(len(psf)):
cube[ext].data[i] = psf[ext].data
cube[ext].header[label_wavelength(nwavelengths, i)] = wl
cube[ext].header[label_wavelength(nwavelengths, i)] = wl.to_value(units.meter)
cube[ext].header.add_history('--- Cube Plane {} ---'.format(i))
for h in psf[ext].header['HISTORY']:
cube[ext].header.add_history(h)
Expand Down

0 comments on commit 89fcb26

Please sign in to comment.