From 4b977da5d97852c20de4066a535b53f76e959119 Mon Sep 17 00:00:00 2001 From: Leo Singer Date: Thu, 5 Sep 2024 12:26:47 -0400 Subject: [PATCH] Move input unit conversion and bounds checking to base class --- CHANGES.rst | 11 +- docs/dust_extinction/dev_model.rst | 2 +- dust_extinction/averages.py | 145 ++++++--------------- dust_extinction/baseclasses.py | 46 +++++-- dust_extinction/helpers.py | 40 +----- dust_extinction/parameter_averages.py | 135 ++++++------------- dust_extinction/shapes.py | 63 +++------ dust_extinction/tests/test_ccm89.py | 1 - dust_extinction/tests/test_corvals.py | 1 - dust_extinction/tests/test_corvals_aves.py | 1 - dust_extinction/tests/test_fm90.py | 12 -- dust_extinction/tests/test_g16.py | 1 - dust_extinction/tests/test_gcc09.py | 1 - dust_extinction/tests/test_p92.py | 32 ----- dust_extinction/tests/test_vcg04.py | 1 - 15 files changed, 152 insertions(+), 340 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 160699a..118d2c1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,16 @@ 1.6.dev (unreleased) ================ -- none yet +- Refactor model input unit conversion and wavenumber bounds checking: + + - The models' evaluate methods are no longer meant to be called directly. + + - The BaseExtModel class now overrides the prepare_inputs method so that + unit conversion and bounds checking is done prior to calling the subclass' + evaluate method. Subclasses are no longer responsible for doing the unit + conversion and bounds checking themselves. + + - Use Astropy's built-in input model units handling. 1.5 (2024-08-16) ================ diff --git a/docs/dust_extinction/dev_model.rst b/docs/dust_extinction/dev_model.rst index 020f7ee..d6f9fd6 100644 --- a/docs/dust_extinction/dev_model.rst +++ b/docs/dust_extinction/dev_model.rst @@ -16,7 +16,7 @@ All All dust extinction models have at least the following: * A member variable `x_range` that that define the valid range of wavelengths. These are defined in inverse microns as is common for extinction curve research. -* A member function `evaluate` that computes the extinction at a given `x` and any model parameter values. The `x` values are checked to be within the valid `x_range`. The `x` values should have astropy.units. If they do not, then they are assumed to be in inverse microns and a warning is issued stating such. +* A member function `evaluate` that computes the extinction at a given `x` and any model parameter values. The `x` values are checked to be within the valid `x_range`. The `x` values passed to the `evaluate` method have no units; the base class `BaseExtModel` will automatically convert whatever units the user provided to inverse microns prior to calling the `evaulate` method. The `evaluate` method should not be called directly. All of these classes used in ``dust_extinction`` are based on the `Model `_ astropy.modeling class. diff --git a/dust_extinction/averages.py b/dust_extinction/averages.py index 0135b8c..6198874 100644 --- a/dust_extinction/averages.py +++ b/dust_extinction/averages.py @@ -5,7 +5,6 @@ from astropy.table import Table from astropy.modeling.models import PowerLaw1D -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range from .baseclasses import BaseExtModel from .shapes import P92, G21, _curve_F99_method @@ -92,13 +91,14 @@ class RL85_MWGC(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + @classmethod + def evaluate(cls, x): r""" RL85 MWGC function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -114,14 +114,9 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function using simple linear interpolation # avoids negative values of alav that happens with cubic splines - f = interp1d(self.obsdata_x, self.obsdata_axav) + f = interp1d(cls.obsdata_x, cls.obsdata_axav) return f(x) @@ -191,13 +186,14 @@ class RRP89_MWGC(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + @classmethod + def evaluate(cls, x): r""" RRP89 MWGC function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -213,14 +209,9 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function using simple linear interpolation # avoids negative values of alav that happens with cubic splines - f = interp1d(self.obsdata_x, self.obsdata_axav) + f = interp1d(cls.obsdata_x, cls.obsdata_axav) return f(x) @@ -293,13 +284,14 @@ class B92_MWAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-3 - def evaluate(self, in_x): + @classmethod + def evaluate(cls, x): """ B92 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -315,14 +307,8 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.name) - # define the function allowing for spline interpolation - f = interp1d(self.obsdata_x, self.obsdata_axav) + f = interp1d(cls.obsdata_x, cls.obsdata_axav) return f(x) @@ -409,13 +395,13 @@ class G03_SMCBar(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 SMCBar function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -447,7 +433,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, self.Rv, C1, C2, @@ -457,8 +443,6 @@ def evaluate(self, in_x): gamma, optnir_axav_x, optnir_axav_y, - self.x_range, - self.__class__.__name__, ) @@ -540,13 +524,13 @@ class G03_LMCAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 LMCAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -576,7 +560,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, self.Rv, C1, C2, @@ -586,8 +570,6 @@ def evaluate(self, in_x): gamma, optnir_axav_x, optnir_axav_y, - self.x_range, - self.__class__.__name__, ) @@ -672,13 +654,13 @@ class G03_LMC2(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 LMC2 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -708,7 +690,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, self.Rv, C1, C2, @@ -718,8 +700,6 @@ def evaluate(self, in_x): gamma, optnir_axav_x, optnir_axav_y, - self.x_range, - self.__class__.__name__, ) @@ -789,13 +769,13 @@ class I05_MWAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + def evaluate(self, x): """ I05 MWAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -811,11 +791,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) @@ -891,13 +866,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ CT06 MWGC function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -913,11 +888,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) @@ -993,13 +963,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ CG06 MWLoc function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1015,11 +985,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) @@ -1145,13 +1110,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ GCC09_MWAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1167,11 +1132,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # P92 parameters fit to the data using uncs as weights p92_fit = P92( BKG_amp=203.805939127, @@ -1201,7 +1161,7 @@ def evaluate(self, in_x): ) # return A(x)/A(V) - return p92_fit(in_x) + return p92_fit(x) class F11_MWGC(BaseExtModel): @@ -1275,13 +1235,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ F11 MWGC function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1297,11 +1257,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) @@ -1402,13 +1357,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G21_MWAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1424,11 +1379,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # G21 parameters fit to the data using uncs as weights g21_fit = G21( scale=0.366, @@ -1445,7 +1395,7 @@ def evaluate(self, in_x): # return A(x)/A(V) # G21 a full dust_extinction model, hence send in x with units - return g21_fit(in_x) + return g21_fit(x) class D22_MWAvg(BaseExtModel): @@ -1526,13 +1476,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ D22_MWAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1548,11 +1498,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # setup the model d22_fit = PowerLaw1D(alpha=1.71, amplitude=0.386, x_0=1.0) @@ -1643,13 +1588,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G24 SMCAvg function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1677,7 +1622,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, self.Rv, C1, C2, @@ -1687,8 +1632,6 @@ def evaluate(self, in_x): gamma, optnir_axav_x, optnir_axav_y, - self.x_range, - self.__class__.__name__, ) @@ -1778,13 +1721,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G24 SMCBumps function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1812,7 +1755,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, self.Rv, C1, C2, @@ -1822,6 +1765,4 @@ def evaluate(self, in_x): gamma, optnir_axav_x, optnir_axav_y, - self.x_range, - self.__class__.__name__, ) diff --git a/dust_extinction/baseclasses.py b/dust_extinction/baseclasses.py index 9b88630..b771912 100644 --- a/dust_extinction/baseclasses.py +++ b/dust_extinction/baseclasses.py @@ -2,8 +2,9 @@ from scipy.interpolate import interp1d from astropy.modeling import Model, Parameter, InputParameterError +from astropy import units as u -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range +from .helpers import _warn_no_units, _test_valid_x_range __all__ = ["BaseExtModel", "BaseExtRvModel", "BaseExtRvAfAModel", "BaseExtGrainModel"] @@ -15,6 +16,40 @@ class BaseExtModel(Model): n_inputs = 1 n_outputs = 1 + input_units = {"x": u.micron**-1} + return_units = {"y": u.dimensionless_unscaled} + input_units_equivalencies = {"x": u.spectral()} + _input_units_strict = True + _input_units_allow_dimensionless = True + + def _prepare_input_single(self, x): + """Check input units and bounds for a single input.""" + + # Get the value of the input in the internal units (1 / micron). + # Because we set the model's input_units_strict and + # input_units_allow_dimensionless to True, by this point one of the + # following must hold: + # - The input is in units of 1 / micron. + # - The input has units of None. + # - The input has units of dimensionless_unscaled. + # - The input is simple Numpy array and not a Quantity. + # In the last three cases, we raise a warning that we are assuming + # that the units are 1 /micron. + if not isinstance(x, u.Quantity): + _warn_no_units() + elif x.unit is None or x.unit is u.dimensionless_unscaled: + x = x.value + _warn_no_units() + else: + assert x.unit == self.input_units["x"] + x = x.value + + _test_valid_x_range(x, self.x_range, self.__class__.__name__) + return x + + def prepare_inputs(self, *args, **kwargs): + xs, *rest = super().prepare_inputs(*args, **kwargs) + return [self._prepare_input_single(x) for x in xs], *rest def extinguish(self, x, Av=None, Ebv=None): """ @@ -170,13 +205,13 @@ class BaseExtGrainModel(BaseExtModel): None """ - def evaluate(self, in_x): + def evaluate(self, x): """ Generic dust grain model function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -192,11 +227,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation # fill value needed to handle numerical issues at the edges # the x values has already been checked to be in range diff --git a/dust_extinction/helpers.py b/dust_extinction/helpers.py index c97d3f3..cca3c3a 100644 --- a/dust_extinction/helpers.py +++ b/dust_extinction/helpers.py @@ -2,46 +2,16 @@ import numpy as np from scipy.special import comb -import astropy.units as u from .warnings import SpectralUnitsWarning -__all__ = ["_get_x_in_wavenumbers", "_test_valid_x_range", "_smoothstep"] +__all__ = ["_warn_no_units", "_test_valid_x_range", "_smoothstep"] -def _get_x_in_wavenumbers(in_x): - """ - Convert input x to wavenumber given x has units. - Otherwise, assume x is in waveneumbers and issue a warning to this effect. - - Parameters - ---------- - in_x : astropy.quantity or simple floats - x values - - Returns - ------- - x : floats - input x values in wavenumbers w/o units - """ - # handles the case where x is a scaler - in_x = np.atleast_1d(in_x) - - # check if in_x is an astropy quantity, if not issue a warning - if not isinstance(in_x, u.Quantity): - warnings.warn( - "x has no units, assuming x units are inverse microns", - SpectralUnitsWarning - ) - - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - with u.add_enabled_equivalencies(u.spectral()): - x_quant = u.Quantity(in_x, 1.0 / u.micron, dtype=np.float64) - - # strip the quantity to avoid needing to add units to all the - # polynomical coefficients - return x_quant.value +def _warn_no_units(): + warnings.warn( + "x has no units, assuming x units are inverse microns", SpectralUnitsWarning + ) def _test_valid_x_range(x, x_range, outname): diff --git a/dust_extinction/parameter_averages.py b/dust_extinction/parameter_averages.py index 6b1023e..07da621 100644 --- a/dust_extinction/parameter_averages.py +++ b/dust_extinction/parameter_averages.py @@ -8,7 +8,7 @@ from astropy.modeling.models import Drude1D, Polynomial1D, PowerLaw1D from .baseclasses import BaseExtRvModel, BaseExtRvAfAModel -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range, _smoothstep +from .helpers import _smoothstep from .averages import G03_SMCBar from .shapes import _curve_F99_method, _modified_drude, FM90 @@ -88,13 +88,13 @@ class CCM89(BaseExtRvModel): x_range = x_range_CCM89 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ CCM89 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -110,15 +110,9 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_CCM89, "CCM89") - # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(0.3 <= x, x < 1.1)) @@ -224,13 +218,13 @@ class O94(BaseExtRvModel): x_range = x_range_O94 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ O94 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -246,15 +240,9 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_O94, "O94") - # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(0.3 <= x, x < 1.1)) @@ -362,13 +350,14 @@ class F99(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_F99 - def evaluate(self, in_x, Rv): + @staticmethod + def evaluate(x, Rv): """ F99 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -430,7 +419,7 @@ def evaluate(self, in_x, Rv): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, Rv, C1, C2, @@ -440,8 +429,6 @@ def evaluate(self, in_x, Rv): gamma, optnir_axav_x, optnir_axebv_y / Rv, - self.x_range, - self.__class__.__name__, ) @@ -511,13 +498,14 @@ class F04(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_F04 - def evaluate(self, in_x, Rv): + @staticmethod + def evaluate(x, Rv): """ F04 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -578,7 +566,7 @@ def evaluate(self, in_x, Rv): # return A(x)/A(V) return _curve_F99_method( - in_x, + x, Rv, C1, C2, @@ -588,8 +576,6 @@ def evaluate(self, in_x, Rv): gamma, optnir_axav_x, optnir_axebv_y / Rv, - self.x_range, - self.__class__.__name__, ) @@ -657,13 +643,13 @@ class VCG04(BaseExtRvModel): x_range = x_range_VCG04 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ VCG04 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] internally wavenumbers are used @@ -678,17 +664,9 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_VCG04, "VCG04") - # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges nuv_indxs = np.where(np.logical_and(3.3 <= x, x <= 8.0)) @@ -777,13 +755,13 @@ class GCC09(BaseExtRvModel): x_range = x_range_GCC09 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ GCC09 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -799,17 +777,9 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_GCC09, "GCC09") - # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges nuv_indxs = np.where(np.logical_and(3.3 <= x, x <= 11.0)) @@ -914,13 +884,14 @@ class M14(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_M14 - def evaluate(self, in_x, Rv): + @staticmethod + def evaluate(x, Rv): """ M14 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -936,11 +907,6 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1157,13 +1123,13 @@ class G16(BaseExtRvAfAModel): x_range = x_range_G16 @staticmethod - def evaluate(in_x, RvA, fA): + def evaluate(x, RvA, fA): """ G16 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1179,11 +1145,6 @@ def evaluate(in_x, RvA, fA): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_G16, "G16") - # just in case someone calls evaluate explicitly RvA = np.atleast_1d(RvA) @@ -1287,13 +1248,13 @@ def __init__(self, Rv=3.1, **kwargs): super().__init__(Rv, **kwargs) - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ F19 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1309,13 +1270,6 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1400,13 +1354,13 @@ def __init__(self, Rv=3.1, **kwargs): super().__init__(Rv, **kwargs) - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ D22 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1422,13 +1376,6 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1499,13 +1446,13 @@ class G23(BaseExtRvModel): Rv_range = [2.3, 5.6] x_range = x_range_G23 - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ G23 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -1521,15 +1468,9 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, "G23") - # setup the a & b coefficient vectors - n_x = len(x) - self.a = np.zeros(n_x) - self.b = np.zeros(n_x) + self.a = np.zeros(x.shape) + self.b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(1.0 / 35.0 <= x, x < 1.0 / 1.0)) diff --git a/dust_extinction/shapes.py b/dust_extinction/shapes.py index 6a7a74c..562999f 100644 --- a/dust_extinction/shapes.py +++ b/dust_extinction/shapes.py @@ -4,7 +4,7 @@ import astropy.units as u from astropy.modeling import Fittable1DModel, Parameter -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range +from .baseclasses import BaseExtModel __all__ = ["FM90", "FM90_B3", "P92", "G21"] @@ -13,7 +13,7 @@ def _curve_F99_method( - in_x, + x, Rv, C1, C2, @@ -23,15 +23,13 @@ def _curve_F99_method( gamma, optnir_axav_x, optnir_axav_y, - valid_x_range, - model_name, ): """ Function to return extinction using F99 method Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -74,13 +72,8 @@ def _curve_F99_method( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, valid_x_range, model_name) - # initialize extinction curve storage - axav = np.zeros(len(x)) + axav = np.zeros(x.shape) # x value above which FM90 parametrization used x_cutval_uv = 10000.0 / 2700.0 @@ -166,7 +159,7 @@ def _modified_drude(x, scale, x_o, gamma_o, asym): return y -class FM90(Fittable1DModel): +class FM90(BaseExtModel, Fittable1DModel): r""" Fitzpatrick & Massa (1990) 6 parameter ultraviolet shape model @@ -257,13 +250,13 @@ class FM90(Fittable1DModel): x_range = x_range_FM90 @staticmethod - def evaluate(in_x, C1, C2, C3, C4, xo, gamma): + def evaluate(x, C1, C2, C3, C4, xo, gamma): """ FM90 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -279,11 +272,6 @@ def evaluate(in_x, C1, C2, C3, C4, xo, gamma): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_FM90, "FM90") - # linear term exvebv = C1 + C2 * x @@ -301,12 +289,10 @@ def evaluate(in_x, C1, C2, C3, C4, xo, gamma): return exvebv @staticmethod - def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): + def fit_deriv(x, C1, C2, C3, C4, xo, gamma): """ Derivatives of the FM90 function with respect to the parameters """ - x = in_x - # useful quantitites x2 = x**2 xo2 = xo**2 @@ -346,7 +332,7 @@ def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): # 'C4': outputs_unit[self.outputs[0]]} -class FM90_B3(Fittable1DModel): +class FM90_B3(BaseExtModel, Fittable1DModel): r""" Fitzpatrick & Massa (1990) 6 parameter ultraviolet shape model Version with bump amplitude B3 = C3/gamma^2 @@ -438,13 +424,13 @@ class FM90_B3(Fittable1DModel): x_range = x_range_FM90 @staticmethod - def evaluate(in_x, C1, C2, B3, C4, xo, gamma): + def evaluate(x, C1, C2, B3, C4, xo, gamma): """ FM90 function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -460,11 +446,6 @@ def evaluate(in_x, C1, C2, B3, C4, xo, gamma): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_FM90, "FM90_B3") - # linear term exvebv = C1 + C2 * x @@ -482,7 +463,7 @@ def evaluate(in_x, C1, C2, B3, C4, xo, gamma): return exvebv -class P92(Fittable1DModel): +class P92(BaseExtModel, Fittable1DModel): r""" Pei (1992) 24 parameter shape model @@ -711,7 +692,7 @@ def _p92_single_term(in_lambda, amplitude, cen_wave, b, n): def evaluate( self, - in_x, + x, BKG_amp, BKG_lambda, BKG_b, @@ -742,7 +723,7 @@ def evaluate( Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -758,11 +739,6 @@ def evaluate( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # calculate the terms lam = 1.0 / x axav = ( @@ -781,7 +757,7 @@ def evaluate( fit_deriv = None -class G21(Fittable1DModel): +class G21(BaseExtModel, Fittable1DModel): r""" Gordon et al. (2021) powerlaw plus two modified Drude profiles (for the 10 & 20 micron silicate features) @@ -894,7 +870,7 @@ class G21(Fittable1DModel): def evaluate( self, - in_x, + x, scale, alpha, sil1_amp, @@ -911,7 +887,7 @@ def evaluate( Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -925,11 +901,6 @@ def evaluate( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, "G21") - wave = 1 / x # powerlaw diff --git a/dust_extinction/tests/test_ccm89.py b/dust_extinction/tests/test_ccm89.py index 47268d0..df834af 100644 --- a/dust_extinction/tests/test_ccm89.py +++ b/dust_extinction/tests/test_ccm89.py @@ -230,4 +230,3 @@ def test_extinction_CCM89_single_values(test_vals): # test np.testing.assert_allclose(tmodel(x), cor_val) - np.testing.assert_allclose(tmodel.evaluate(x, 3.1), cor_val) diff --git a/dust_extinction/tests/test_corvals.py b/dust_extinction/tests/test_corvals.py index ce9ef1c..0c29e4b 100644 --- a/dust_extinction/tests/test_corvals.py +++ b/dust_extinction/tests/test_corvals.py @@ -108,4 +108,3 @@ def test_corvals(model_class, test_parameters): # test single value evalutation for x, y in zip(x_vals, y_vals): np.testing.assert_allclose(tmodel(x), y, atol=tol) - np.testing.assert_allclose(tmodel.evaluate(x, Rv), y, atol=tol) diff --git a/dust_extinction/tests/test_corvals_aves.py b/dust_extinction/tests/test_corvals_aves.py index 34985eb..c9c2f4d 100644 --- a/dust_extinction/tests/test_corvals_aves.py +++ b/dust_extinction/tests/test_corvals_aves.py @@ -19,7 +19,6 @@ def test_corvals(model_class): # test single value evalutation for x, y in zip(x_vals, y_vals): np.testing.assert_allclose(tmodel(x), y, rtol=tol) - np.testing.assert_allclose(tmodel.evaluate(x), y, rtol=tol) @pytest.mark.parametrize("model_class", ave_models) diff --git a/dust_extinction/tests/test_fm90.py b/dust_extinction/tests/test_fm90.py index 30dd842..ebb717c 100644 --- a/dust_extinction/tests/test_fm90.py +++ b/dust_extinction/tests/test_fm90.py @@ -48,18 +48,6 @@ def test_extinction_FM90_single_values(xtest_vals): # test np.testing.assert_allclose(tmodel(x), cor_val) - np.testing.assert_allclose( - tmodel.evaluate( - x, - FM90.C1.default, - FM90.C2.default, - FM90.C3.default, - FM90.C4.default, - FM90.xo.default, - FM90.gamma.default, - ), - cor_val, - ) def test_FM90_fitting(): diff --git a/dust_extinction/tests/test_g16.py b/dust_extinction/tests/test_g16.py index e95e455..6d8102a 100644 --- a/dust_extinction/tests/test_g16.py +++ b/dust_extinction/tests/test_g16.py @@ -61,7 +61,6 @@ def test_extinction_G16_fA_1_single_values(test_vals): # test np.testing.assert_allclose(tmodel(x), cor_val, rtol=tolerance) - np.testing.assert_allclose(tmodel.evaluate(x, 3.1, 1.0), cor_val, rtol=tolerance) def test_extinction_G16_extinguish_values_Ebv(): diff --git a/dust_extinction/tests/test_gcc09.py b/dust_extinction/tests/test_gcc09.py index 741aa69..af1cd7f 100644 --- a/dust_extinction/tests/test_gcc09.py +++ b/dust_extinction/tests/test_gcc09.py @@ -130,4 +130,3 @@ def test_extinction_GCC09_single_values(test_vals): # test np.testing.assert_allclose(tmodel(x), cor_val, rtol=1e-5) - np.testing.assert_allclose(tmodel.evaluate(x, 3.1), cor_val, rtol=1e-5) diff --git a/dust_extinction/tests/test_p92.py b/dust_extinction/tests/test_p92.py index deaac79..1cc5468 100644 --- a/dust_extinction/tests/test_p92.py +++ b/dust_extinction/tests/test_p92.py @@ -112,38 +112,6 @@ def test_extinction_P92_single_values(xtest_vals): # test np.testing.assert_allclose(tmodel(x), cor_val, rtol=0.25, atol=0.01) - np.testing.assert_allclose( - tmodel.evaluate( - x, - P92.BKG_amp.default, - P92.BKG_lambda.default, - P92.BKG_b.default, - P92.BKG_n.default, - P92.FUV_amp.default, - P92.FUV_lambda.default, - P92.FUV_b.default, - P92.FUV_n.default, - P92.NUV_amp.default, - P92.NUV_lambda.default, - P92.NUV_b.default, - P92.NUV_n.default, - P92.SIL1_amp.default, - P92.SIL1_lambda.default, - P92.SIL1_b.default, - P92.SIL1_n.default, - P92.SIL2_amp.default, - P92.SIL2_lambda.default, - P92.SIL2_b.default, - P92.SIL2_n.default, - P92.FIR_amp.default, - P92.FIR_lambda.default, - P92.FIR_b.default, - P92.FIR_n.default, - ), - cor_val, - rtol=0.25, - atol=0.01, - ) @pytest.mark.skip(reason="failing due to an issue with the fitting") diff --git a/dust_extinction/tests/test_vcg04.py b/dust_extinction/tests/test_vcg04.py index fd900de..86db990 100644 --- a/dust_extinction/tests/test_vcg04.py +++ b/dust_extinction/tests/test_vcg04.py @@ -70,4 +70,3 @@ def test_extinction_VCG04_single_values(test_vals): # test np.testing.assert_allclose(tmodel(x), cor_val, rtol=1e-5) - np.testing.assert_allclose(tmodel.evaluate(x, 3.1), cor_val, rtol=1e-5)