Skip to content

Commit

Permalink
WIP: deprecate helper-specific methods in specviz/2d
Browse files Browse the repository at this point in the history
  • Loading branch information
kecnry committed Jan 14, 2025
1 parent 2084d96 commit 819abf7
Show file tree
Hide file tree
Showing 17 changed files with 328 additions and 233 deletions.
6 changes: 3 additions & 3 deletions docs/cubeviz/export_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ For a list of model labels:

.. code-block:: python
models = cubeviz.get_models()
models = cubeviz.plugins['Model Fitting'].get_models()
models
Once you know the model labels, to get a specific model:

.. code-block:: python
mymodel = cubeviz.get_models(model_label="ModelLabel", x=10)
mymodel = cubeviz.plugins['Model Fitting'].get_models(model_label="ModelLabel", x=10)
To extract all of the model parameters:

.. code-block:: python
myparams = cubeviz.get_model_parameters(model_label="ModelLabel", x=x, y=y)
myparams = cubeviz.plugins['Model Fitting'].get_model_parameters(model_label="ModelLabel", x=x, y=y)
myparams
where the ``model_label`` parameter identifies which model should be returned and
Expand Down
6 changes: 3 additions & 3 deletions docs/specviz/export_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ For a list of model labels:

.. code-block:: python
models = specviz.get_models()
models = specviz.plugins['Model Fitting'].get_models()
models
Once you know the model labels, to get a specific model:

.. code-block:: python
mymodel = specviz.get_models(model_label="ModelLabel")
mymodel = specviz.plugins['Model Fitting'].get_models(model_label="ModelLabel")
To extract all of the model parameters:

.. code-block:: python
myparams = specviz.get_model_parameters(model_label="ModelLabel")
myparams = specviz.plugins['Model Fitting'].get_model_parameters(model_label="ModelLabel")
myparams
where the ``model_label`` parameter identifies which model should be returned.
Expand Down
9 changes: 6 additions & 3 deletions jdaviz/configs/cubeviz/helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from astropy.utils.decorators import deprecated

from jdaviz.configs.default.plugins.line_lists.line_list_mixin import LineListMixin
from jdaviz.configs.specviz import Specviz
from jdaviz.core.events import AddDataMessage, SnackbarMessage
Expand Down Expand Up @@ -97,6 +99,7 @@ def load_data(self, data, data_label=None, override_cube_limit=False, **kwargs):
color='warning', sender=self, timeout=10000)
self.app.hub.broadcast(msg)

@deprecated(since="4.2", alternative="plugins['Slice'].value")
def select_wavelength(self, wavelength):
"""
Select the slice closest to the provided wavelength.
Expand All @@ -112,6 +115,7 @@ def select_wavelength(self, wavelength):
self.select_slice(wavelength)

@property
@deprecated(since="4.2", alternative="viewers['spectrum-viewer']")
def specviz(self):
"""
A Specviz helper (:class:`~jdaviz.configs.specviz.helper.Specviz`) for the Jdaviz
Expand Down Expand Up @@ -149,8 +153,7 @@ def get_data(self, data_label=None, spatial_subset=None, spectral_subset=None,
spectral_subset=spectral_subset,
cls=cls, use_display_units=use_display_units)

# Need this method for Imviz Aperture Photometry plugin.

@deprecated(since="4.2", alternative="plugins['Aperture Photometry'].export_table()")
def get_aperture_photometry_results(self):
"""Return aperture photometry results, if any.
Results are calculated using :ref:`cubeviz-aper-phot` plugin.
Expand All @@ -161,4 +164,4 @@ def get_aperture_photometry_results(self):
Photometry results if available or `None` otherwise.
"""
return self.plugins['Aperture Photometry']._obj.export_table()
return self.plugins['Aperture Photometry'].export_table()
6 changes: 3 additions & 3 deletions jdaviz/configs/cubeviz/plugins/tests/test_cubeviz_aperphot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_cubeviz_aperphot_cube_orig_flux(cubeviz_helper, image_cube_hdu_obj_micr
plg.dataset_selected = "test[FLUX]"
plg.aperture_selected = "Subset 1"
plg.vue_do_aper_phot()
row = cubeviz_helper.get_aperture_photometry_results()[0]
row = plg.export_table()[0]

# Basically, we should recover the input rectangle here.
assert_allclose(row["xcenter"], 1 * u.pix)
Expand All @@ -51,7 +51,7 @@ def test_cubeviz_aperphot_cube_orig_flux(cubeviz_helper, image_cube_hdu_obj_micr
cube_slice_plg = cubeviz_helper.plugins["Slice"]._obj
cube_slice_plg.vue_goto_first()
plg.vue_do_aper_phot()
row = cubeviz_helper.get_aperture_photometry_results()[1]
row = plg.export_table()[1]

# Same rectangle but different slice value.
assert_allclose(row["xcenter"], 1 * u.pix)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_cubeviz_aperphot_cube_orig_flux(cubeviz_helper, image_cube_hdu_obj_micr
plg.dataset_selected = "test[FLUX] collapsed"
plg.aperture_selected = "Subset 1"
plg.vue_do_aper_phot()
row = cubeviz_helper.get_aperture_photometry_results()[2]
row = plg.export_table()[2]

# Basically, we should recover the input rectangle here.
assert_allclose(row["xcenter"], 1 * u.pix)
Expand Down
2 changes: 0 additions & 2 deletions jdaviz/configs/default/plugins/data_quality/data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class DataQuality(PluginTemplateMixin, ViewerSelectMixin):
"""
template_file = __file__, "data_quality.vue"

irrelevant_msg = Unicode("").tag(sync=True)

# `layer` is the science data layer
science_layer_multiselect = Bool(False).tag(sync=True)
science_layer_items = List().tag(sync=True)
Expand Down
210 changes: 209 additions & 1 deletion jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
Label of the residuals to apply when calling :meth:`calculate_fit` if ``residuals_calculate``
is ``True``.
* :meth:`calculate_fit`
* :meth:`fitted_models`
* :meth:`get_models`
* :meth:`get_model_parameters`
"""
dialog = Bool(False).tag(sync=True)
template_file = __file__, "model_fitting.vue"
Expand Down Expand Up @@ -187,7 +190,8 @@ def user_api(self):
'get_model_component', 'set_model_component', 'reestimate_model_parameters',
'equation', 'equation_components',
'add_results', 'residuals_calculate', 'residuals']
expose += ['calculate_fit', 'clear_table', 'export_table']
expose += ['calculate_fit', 'clear_table', 'export_table',
'fitted_models', 'get_models', 'get_model_parameters']
return PluginUserApi(self, expose=expose)

def _param_units(self, param, model_type=None):
Expand Down Expand Up @@ -779,6 +783,210 @@ def equation_components(self):
"""
return re.split(r'[+*/-]', self.equation.value.replace(' ', ''))

@property
def fitted_models(self):
"""
Dictionary of all previously fitted models.
"""
# TODO: store this internally instead of within the app
return self.app.fitted_models

def get_models(self, models=None, model_label=None, x=None, y=None):
"""
Loop through all models and output models of the label model_label.
If x or y is set, return model_labels of those (x, y) coordinates.
If x and y are None, print all models regardless of coordinates.
Parameters
----------
models : dict
A dict of models, with the key being the label name and the value
being an `astropy.modeling.CompoundModel` object. Defaults to
`fitted_models` if no parameter is provided.
model_label : str
The name of the model that will be found and returned. If it
equals default, every model present will be returned.
x : int
The x coordinate of the model spaxels that will be returned.
y : int
The y coordinate of the model spaxels that will be returned.
Returns
-------
selected_models : dict
Dictionary of the selected models.
"""
selected_models = {}
# If models is not provided, use the app's fitted models
if not models:
models = self.fitted_models

# Loop through all keys in the dict models
for label in models:
# Prevent "Model 2" from being returned when model_label is "Model"
if model_label is not None:
if label.split(" (")[0] != model_label:
continue

# If no label was provided, use label name without coordinates.
if model_label is None and " (" in label:
find_label = label.split(" (")[0]
# If coordinates are not present, just use the label.
elif model_label is None:
find_label = label
else:
find_label = model_label

# If x and y are set, return keys that match the model plus that
# coordinate pair. If only x or y is set, return keys that fit
# that value for the appropriate coordinate.
if x is not None and y is not None:
find_label = r"{} \({}, {}\)".format(find_label, x, y)
elif x:
find_label = r"{} \({}, .+\)".format(find_label, x)
elif y:
find_label = r"{} \(.+, {}\)".format(find_label, y)

if re.search(find_label, label):
selected_models[label] = models[label]

return selected_models

def get_model_parameters(self, models=None, model_label=None, x=None, y=None):
"""
Convert each parameter of model inside models into a coordinate that
maps the model name and parameter name to a `astropy.units.Quantity`
object.
Parameters
----------
models : dict
A dictionary where the key is a model name and the value is an
`astropy.modeling.CompoundModel` object.
model_label : str
Get model parameters for a particular model by inputting its label.
x : int
The x coordinate of the model spaxels that will be returned from
get_models.
y : int
The y coordinate of the model spaxels that will be returned from
get_models.
Returns
-------
:dict: a dictionary of the form
{model name: {parameter name: [[`astropy.units.Quantity`]]}}
for 3d models or
{model name: {parameter name: `astropy.units.Quantity`}} where the
Quantity object represents the parameter value and unit of one of
spaxel models or the 1d models, respectively.
"""
if models and model_label:
models = self.get_models(models=models, model_label=model_label, x=x, y=y)
elif models is None and model_label:
models = self.get_models(model_label=model_label, x=x, y=y)
elif models is None:
models = self.fitted_models

data_shapes = {}
for label in models:
data_label = label.split(" (")[0]
if data_label not in data_shapes:
data_shapes[data_label] = self.app.data_collection[data_label].data.shape

param_dict = {}
parameters_cube = {}
param_x_y = {}
param_units = {}

for label in models:
# 3d models take the form of "Model (1,2)" so this if statement
# looks for that style and separates out the pertinent information.
if " (" in label:
label_split = label.split(" (")
model_name = label_split[0]
x = int(label_split[1].split(", ")[0])
y = int(label_split[1].split(", ")[1][:-1])

# x and y values are added to this dict where they will be used
# to convert the models of each spaxel into a single
# coordinate in the parameters_cube dictionary.
if model_name not in param_x_y:
param_x_y[model_name] = {'x': [], 'y': []}
if x not in param_x_y[model_name]['x']:
param_x_y[model_name]['x'].append(x)
if y not in param_x_y[model_name]['y']:
param_x_y[model_name]['y'].append(y)

# 1d models will be handled by this else statement.
else:
model_name = label

if model_name not in param_dict:
param_dict[model_name] = list(models[label].param_names)

# This adds another dictionary as the value of
# parameters_cube[model_name] where the key is the parameter name
# and the value is either a 2d array of zeros or a single value, depending
# on whether the model in question is 3d or 1d, respectively.
for model_name in param_dict:
if model_name in param_x_y:
parameters_cube[model_name] = {x: np.zeros(shape=data_shapes[model_name][:2])
for x in param_dict[model_name]}
else:
parameters_cube[model_name] = {x: 0
for x in param_dict[model_name]}

# This loop handles the actual placement of param.values and
# param.units into the parameter_cubes dictionary.
for label in models:
if " (" in label:
label_split = label.split(" (")
model_name = label_split[0]

# If the get_models method is used to build a dictionary of
# models and a value is set for the x or y parameters, that
# will mean that only one x or y value is present in the
# models.
if len(param_x_y[model_name]['x']) == 1:
x = 0
else:
x = int(label_split[1].split(", ")[0])

if len(param_x_y[model_name]['y']) == 1:
y = 0
else:
y = int(label_split[1].split(", ")[1][:-1])

param_units[model_name] = {}

for name in param_dict[model_name]:
param = getattr(models[label], name)
parameters_cube[model_name][name][x][y] = param.value
param_units[model_name][name] = param.unit
else:
model_name = label
param_units[model_name] = {}

# 1d models do not have anything set of param.unit, so the
# return_units and input_units properties need to be used
# instead, depending on the type of parameter `name` is.
for name in param_dict[model_name]:
param = getattr(models[label], name)
parameters_cube[model_name][name] = param.value
param_units[model_name][name] = param.unit

# Convert values of parameters_cube[key][param_name] into u.Quantity
# objects that contain the appropriate unit set in
# param_units[key][param_name]
for key in parameters_cube:
for param_name in parameters_cube[key]:
parameters_cube[key][param_name] = u.Quantity(
parameters_cube[key][param_name],
param_units[key].get(param_name, None))

return parameters_cube

def vue_add_model(self, event):
self.create_model_component()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_parameter_retrieval(cubeviz_helper, spectral_cube_wcs):
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
plugin.calculate_fit()

params = cubeviz_helper.get_model_parameters()
params = cubeviz_helper.plugins['Model Fitting'].get_model_parameters()
slope_res = np.zeros((3, 4))
slope_res[2, 2] = 1.0

Expand Down
20 changes: 19 additions & 1 deletion jdaviz/configs/default/plugins/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def user_api(self):
elif isinstance(self, TableViewer):
expose += []
else:
expose += ['set_limits', 'reset_limits']
expose += ['set_limits', 'reset_limits', 'set_tick_format']
return ViewerUserApi(self, expose=expose)

@property
Expand Down Expand Up @@ -180,6 +180,24 @@ def get_limits(self):
"""
return self.state.x_min, self.state.x_max, self.state.y_min, self.state.y_max

def set_tick_format(self, fmt, axis):
"""
Manually set the tick format of one of the axes.
Parameters
----------
fmt : str
Format of tick marks.
For example, ``'0.1e'`` to set scientific notation or ``'0.2f'`` to turn it off.
axis : {x, y}
The viewer axis.
"""
if axis not in ('x', 'y'):
raise ValueError("axis must be 'x' or 'y'")
# Examples of values for fmt are '0.1e' or '0.2f'
axis = {'x': 0, 'y': 1}[axis]
self.figure.axes[axis].tick_format = fmt

@property
def native_marks(self):
"""
Expand Down
Loading

0 comments on commit 819abf7

Please sign in to comment.