diff --git a/webbpsf/webbpsf_core.py b/webbpsf/webbpsf_core.py index ed536cd2..292f6f1c 100644 --- a/webbpsf/webbpsf_core.py +++ b/webbpsf/webbpsf_core.py @@ -1813,6 +1813,7 @@ def load_wss_opd_by_date(self, date=None, choice='closest', verbose=True, plot=F opd_fn = webbpsf.mast_wss.get_opd_at_time(date, verbose=verbose, choice=choice, **kwargs) self.load_wss_opd(opd_fn, verbose=verbose, plot=plot, **kwargs) + @poppy.utils.quantity_input(wavelengths=units.meter) def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, add_distortion=True, *args, **kwargs): """Calculate a spectral datacube of PSFs: Simplified, much MUCH faster version. @@ -1869,7 +1870,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, # Set up cube and initialize structure based on PSF at a representative wavelength _log.info('Starting fast/simplified multiwavelength data cube calculation.') ref_wave = np.mean(wavelengths) - MIN_REF_WAVE = 2e-6 # This must not be too short, to avoid phase wrapping for the C3 bump + MIN_REF_WAVE = 2e-6 * units.meter # This must not be too short, to avoid phase wrapping for the C3 bump if ref_wave < MIN_REF_WAVE: ref_wave = MIN_REF_WAVE log_message = ( @@ -1897,7 +1898,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, ext = 0 cubefast[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1])) cubefast[ext].data[0] = psf[ext].data - cubefast[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0] + cubefast[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0].to_value(units.meter) # Fast way. Assumes wavelength-independent phase and amplitude at the exit pupil!! if compare_methods: @@ -1928,7 +1929,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, wl = wavelengths[i] psfw = quickosys.calc_psf(wavelength=wl, normalize='None') cubefast[0].data[i] = psfw[0].data - cubefast[ext].header[label_wavelength(nwavelengths, i)] = wavelengths[i] + cubefast[ext].header[label_wavelength(nwavelengths, i)] = wavelengths[i].to_value(units.meter) cubefast[0].header['NWAVES'] = nwavelengths @@ -1945,7 +1946,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, for ext in range(len(psf)): cube[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1])) cube[ext].data[0] = psf[ext].data - cube[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0] + cube[ext].header[label_wavelength(nwavelengths, 0)] = wavelengths[0].to_value(units.meter) # iterate rest of wavelengths print('Running standard way') @@ -1954,7 +1955,7 @@ def calc_datacube_fast(self, wavelengths, compare_methods=False, outfile=None, psf = self.calc_psf(*args, monochromatic=wl, **kwargs) for ext in range(len(psf)): cube[ext].data[i] = psf[ext].data - cube[ext].header[label_wavelength(nwavelengths, i)] = wl + cube[ext].header[label_wavelength(nwavelengths, i)] = wl.to_value(units.meter) cube[ext].header.add_history('--- Cube Plane {} ---'.format(i)) for h in psf[ext].header['HISTORY']: cube[ext].header.add_history(h)