Skip to content

Commit

Permalink
added docstrings to functions and moved some functions to helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
comane committed Jan 22, 2025
1 parent 16fe05f commit 6b6a612
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 68 deletions.
93 changes: 25 additions & 68 deletions validphys2/src/validphys/closuretest/multiclosure_nsigma.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
"""
This module contains the functions to compute the consistency / inconsistency sets.
Assuming that we have two datasets A and B, and that we are investigating whether A is consistent or not
we can define the following sets:
1_α = {i | nσ_i > Z_α} ...
TODO
"""

import dataclasses
Expand Down Expand Up @@ -83,43 +78,6 @@ def chi2_nsigma_deviation(central_member_chi2: CentralChi2Data) -> float:
fits_data = collect("data", ("fits", "fitinputcontext"))


def is_weighted(fits_data: list) -> bool:
    """
    Returns whether the considered multiclosure test has been weighted or not.

    A dataset counts as weighted when its ``weight`` attribute differs from 1.
    If the weighted datasets are not the same for all fits,
    or there is more than one weighted dataset, an error is raised.

    Parameters
    ----------
    fits_data: list
        List of data for each fit. Each element must expose a ``datasets``
        iterable whose items have ``name`` and ``weight`` attributes.

    Returns
    -------
    bool
        True if a single dataset is weighted (identically in every fit),
        False if no dataset is weighted or if ``fits_data`` is empty.

    Raises
    ------
    ValueError
        If the fits do not share the same weighted datasets, or if more than
        one dataset is weighted.
    """
    # One frozenset of weighted dataset names per fit. Gathering them in a set
    # collapses identical fits, so more than one element means a mismatch.
    weighted_ds_sets = {
        frozenset(ds.name for ds in data.datasets if ds.weight != 1) for data in fits_data
    }

    # Ensure all fits have the same set of weighted datasets
    if len(weighted_ds_sets) > 1:
        error_msg = "Weighted datasets are not the same for all fits in the same multiclosure test (dataspec)."
        log.error(error_msg)
        raise ValueError(error_msg)

    # At most one distinct set is left. The default covers an empty
    # ``fits_data`` so that no StopIteration escapes to the caller.
    weighted_ds = next(iter(weighted_ds_sets), frozenset())

    # Ensure there is at most one weighted dataset
    if len(weighted_ds) > 1:
        error_msg = "Only one dataset can be weighted in a multiclosure test."
        log.error(error_msg)
        raise ValueError(error_msg)

    return bool(weighted_ds)


@dataclasses.dataclass
class MulticlosureNsigma:
"""
Expand Down Expand Up @@ -179,23 +137,6 @@ def multiclosurefits_nsigma(
dataspecs_multiclosurefits_nsigma = collect("multiclosurefits_nsigma", ("dataspecs",))


def n_fits(dataspecs):
    """
    Returns the number of fits common to every dataspec.

    Each dataspec is read through its ``'fits'`` entry and the length of that
    entry is taken as its number of fits.

    Raises
    ------
    ValueError
        If the number of fits is not the same across dataspecs.
    """
    # Distinct fit counts found across the dataspecs
    fit_counts = {len(spec['fits']) for spec in dataspecs}

    if len(fit_counts) <= 1:
        return next(iter(fit_counts))

    error_msg = "The number of fits is not the same across dataspecs."
    log.error(error_msg)
    raise ValueError(error_msg)


@dataclasses.dataclass
class NsigmaAlpha:
"""
Expand Down Expand Up @@ -228,7 +169,7 @@ def def_of_nsigma_alpha(
The name of the weighted dataset.
complement: bool, default=False
Whether to compute the complement set 1 alpha values.
Returns
-------
NsigmaAlpha
Expand Down Expand Up @@ -275,13 +216,18 @@ def comp_nsigma_alpha(multiclosurefits_nsigma: pd.DataFrame, weighted_dataset: s

def set_1_alpha(dataspecs_nsigma_alpha: list) -> dict:
"""
Returns the set 1 alpha values.
Returns the set 1 alpha values, these are defined as
1_{\alpha} = {i | n_{\sigma}^{i} > Z_{\alpha}}
where i is the index of the fit and n_{\sigma}^{i} is the n-sigma value computed
for fit i.
Parameters
----------
dataspecs_nsigma_alpha: list
List of NsigmaAlpha dataclasses.
Returns
-------
dict
Expand All @@ -293,13 +239,18 @@ def set_1_alpha(dataspecs_nsigma_alpha: list) -> dict:

def set_3_alpha(dataspecs_nsigma_alpha: list) -> dict:
"""
Same as the set 1 alpha values, but for the weighted datasets.
Same as the set 1 alpha values, but for the weighted fits.
3_{\alpha} = {i | n_{weighted, \sigma}^{i} > Z_{\alpha}}
where i is the index of the fit and n_{weighted, \sigma}^{i} is the n-sigma value computed
on the weighted dataset for fit i.
Parameters
----------
dataspecs_nsigma_alpha: list
List of NsigmaAlpha dataclasses.
Returns
-------
dict
Expand All @@ -320,7 +271,7 @@ def comp_set_1_alpha(dataspecs_comp_nsigma_alpha: list) -> dict:

def comp_set_3_alpha(dataspecs_comp_nsigma_alpha: list) -> dict:
"""
Same as the complement set 1 alpha values, but for the weighted datasets.
Returns the complement set 3 alpha values.
"""
for dataspec_nsigma in dataspecs_comp_nsigma_alpha:
if dataspec_nsigma.is_weighted:
Expand All @@ -342,7 +293,7 @@ def def_set_2(
The name of the weighted dataset.
complement: bool, default=False
Whether to compute the complement set 2 alpha values.
Returns
-------
dict
Expand Down Expand Up @@ -384,7 +335,13 @@ def def_set_2(

def set_2_alpha(dataspecs_multiclosurefits_nsigma: list, weighted_dataset: str) -> dict:
    r"""
    Computes the set 2 alpha values. The set 2 is defined as:

    2_{\alpha} = {i | n_{weighted, \sigma}^{i} - n_{ref, \sigma}^{i}> + Z_{\alpha}}

    where the n-sigma is computed on all datasets that are not the weighted dataset.
    Moreover if for a fit i any dataset has a n-sigma value greater than Z_{\alpha}, then
    the fit i is included in the set.

    NOTE(review): the "> +" in the formula above looks like a typo; confirm the
    intended threshold against ``def_set_2``.

    This is a thin wrapper that calls ``def_set_2`` with ``complement=False``.

    Parameters
    ----------
    dataspecs_multiclosurefits_nsigma: list
        Forwarded unchanged to ``def_set_2``.
    weighted_dataset: str
        The name of the weighted dataset.

    Returns
    -------
    dict
        Whatever ``def_set_2`` returns for the non-complement case.
    """
    # The docstring is raw so that "\alpha" and "\sigma" are kept literally
    # instead of being parsed as string escape sequences.
    return def_set_2(dataspecs_multiclosurefits_nsigma, weighted_dataset, complement=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,57 @@ def compute_nsigma_critical_value(

z_alpha = norm.ppf(1 - alpha)
return c, z_alpha


def is_weighted(fits_data: list) -> bool:
    """
    Returns whether the considered multiclosure test has been weighted or not.

    A dataset counts as weighted when its ``weight`` attribute differs from 1.
    If the weighted datasets are not the same for all fits,
    or there is more than one weighted dataset, an error is raised.

    Parameters
    ----------
    fits_data: list
        List of data for each fit. Each element must expose a ``datasets``
        iterable whose items have ``name`` and ``weight`` attributes.

    Returns
    -------
    bool
        True if a single dataset is weighted (identically in every fit),
        False if no dataset is weighted or if ``fits_data`` is empty.

    Raises
    ------
    ValueError
        If the fits do not share the same weighted datasets, or if more than
        one dataset is weighted.
    """
    # Build one immutable set of weighted names per fit; duplicates collapse,
    # so the outer set holds one entry per distinct weighting configuration.
    distinct_weightings = {
        frozenset(ds.name for ds in data.datasets if ds.weight != 1) for data in fits_data
    }

    # Ensure all fits have the same set of weighted datasets
    if len(distinct_weightings) > 1:
        error_msg = "Weighted datasets are not the same for all fits in the same multiclosure test (dataspec)."
        log.error(error_msg)
        raise ValueError(error_msg)

    # With no fits there is nothing weighted; the default avoids leaking a
    # StopIteration out of this function.
    weighted_ds = next(iter(distinct_weightings), frozenset())

    # Ensure there is at most one weighted dataset
    if len(weighted_ds) > 1:
        error_msg = "Only one dataset can be weighted in a multiclosure test."
        log.error(error_msg)
        raise ValueError(error_msg)

    return bool(weighted_ds)


def n_fits(dataspecs):
    """
    Returns the number of fits shared by all dataspecs.

    The number of fits of a dataspec is the length of its ``'fits'`` entry.

    Raises
    ------
    ValueError
        If the number of fits is not the same across dataspecs.
    """
    # Collapse the per-dataspec counts; a consistent input leaves one value
    distinct_counts = set(map(lambda spec: len(spec['fits']), dataspecs))

    if len(distinct_counts) > 1:
        message = "The number of fits is not the same across dataspecs."
        log.error(message)
        raise ValueError(message)

    (*_, count) = (next(iter(distinct_counts)),)
    return count

0 comments on commit 6b6a612

Please sign in to comment.