From 01ffd200930437d198ca03d200eabc240280128f Mon Sep 17 00:00:00 2001
From: tomsail
Date: Mon, 11 Nov 2024 12:58:51 +0100
Subject: [PATCH] feat: streamlined API, tests and README file

---
 README.md                   | 168 ++++++++++++++----------------
 seastats/__init__.py        | 197 +++++++++++++++++++++++++++++++++++-
 seastats/stats.py           | 130 ------------------------
 seastats/storms.py          |  64 ------------
 tests/compute_stats_test.py |  27 ++---
 5 files changed, 278 insertions(+), 308 deletions(-)

diff --git a/README.md b/README.md
index 6d95210..64c970c 100644
--- a/README.md
+++ b/README.md
@@ -4,67 +4,60 @@
 * `sim`: modelled surge time series
 * `mod`: observed surge time series
-The main functions are:
-* `get_stats()`: to get [general comparison metrics](#general-metrics) between two time series
-* `match_extremes()`: a PoT selection is done on the observed signal. Function returns the decreasing extreme event peak values for observed and modeled signals (and time lag between events). See details [below](#extreme-events)
-* `storm_metrics()`: functions returns storm metrics. See details [below](#storm-metrics)
-
-# General metrics
+The main function is:
 
 ```python
-stats = get_stats(sim: pd.Series, obs: pd.Series)
-```
-returns the following dictionary:
-
+def get_stats(
+    sim: pd.Series,
+    obs: pd.Series,
+    metrics: Sequence[str] = None,
+    quantile: float = 0,
+    cluster: int = 72,
+    round: int = -1
+) -> dict[str, float]
 ```
-{
- 'bias': 0.01,
- 'rmse': 0.105,
- 'rms': 0.106,
- 'rms_95': 0.095,
- 'sim_mean': 0.01,
- 'obs_mean': -0.0,
- 'sim_std': 0.162,
- 'obs_std': 0.142,
- 'nse': 0.862,
- 'lamba': 0.899,
- 'cr': 0.763,
- 'cr_95': 0.489,
- 'slope': 0.215,
- 'intercept': 0.01,
- 'slope_pp': 0.381,
- 'intercept_pp': 0.012,
- 'mad': 0.062,
- 'madp': 0.207,
- 'madc': 0.269,
- 'kge': 0.71
-}
-```
-
-with:
-* `bias`: Bias
-* `rmse`: Root Mean Square Error
-* `rms`: Root Mean Square
-* `rms_95`: Root Mean Square for data points above 95th percentile
-* `sim_mean`: Mean of simulated values
-* `obs_mean`: Mean of observed values
-* `sim_std`: Standard deviation of simulated values
-* `obs_std`: Standard deviation of observed values
-* `nse`: Nash-Sutcliffe Efficiency
-* `lamba`: Lambda index
-* `cr`: Pearson Correlation coefficient
-* `cr_95`: Pearson Correlation coefficient for data points above 95th percentile
-* `slope`: Slope of Model/Obs correlation
-* `intercept`: Intercept of Model/Obs correlation
-* `slope_pp`: Slope of Model/Obs correlation of percentiles
-* `intercept_pp`: Intercept of Model/Obs correlation of percentiles
-* `mad`: Mean Absolute Deviation
-* `madp`: Mean Absolute Deviation of percentiles
-* `madc`: `mad + madp`
-* `kge`: Kling–Gupta Efficiency
-
-Most of the paremeters are detailed below:
-
+Calculates various statistical metrics between the simulated and observed time series data.
+## Parameters:
+ * **sim** (pd.Series). The simulated time series data.
+ * **obs** (pd.Series). The observed time series data.
+ * **metrics** (list[str]). (Optional) The list of statistical metrics to calculate. If None, all items in SUGGESTED_METRICS will be calculated. If ["all"], all items in SUPPORTED_METRICS will be calculated. Default is None.
+ * **quantile** (float). (Optional) Quantile used to calculate the metrics. Default is 0 (no selection).
+ * **cluster** (int). (Optional) Cluster duration for grouping storm events. Default is 72 hours.
+ * **round** (int). (Optional) Number of decimals to round the results to. Default is -1 (no rounding).
+
+Returns a dictionary containing the calculated metrics and their corresponding values, grouped into two types:
+* [The "general" metrics](#general-metrics): all the basic metrics needed for signal comparison (RMSE, RMS, correlation, etc.). See details below.
+  * `bias`: Bias
+  * `rmse`: Root Mean Square Error
+  * `rms`: Root Mean Square
+  * `rms_95`: Root Mean Square for data points above 95th percentile
+  * `sim_mean`: Mean of simulated values
+  * `obs_mean`: Mean of observed values
+  * `sim_std`: Standard deviation of simulated values
+  * `obs_std`: Standard deviation of observed values
+  * `mae`: Mean Absolute Error
+  * `mse`: Mean Square Error
+  * `nse`: Nash-Sutcliffe Efficiency
+  * `lamba`: Lambda index
+  * `cr`: Pearson Correlation coefficient
+  * `cr_95`: Pearson Correlation coefficient for data points above 95th percentile
+  * `slope`: Slope of Model/Obs correlation
+  * `intercept`: Intercept of Model/Obs correlation
+  * `slope_pp`: Slope of Model/Obs correlation of percentiles
+  * `intercept_pp`: Intercept of Model/Obs correlation of percentiles
+  * `mad`: Mean Absolute Deviation
+  * `madp`: Mean Absolute Deviation of percentiles
+  * `madc`: `mad + madp`
+  * `kge`: Kling–Gupta Efficiency
+* [The storm metrics](#storm-metrics): a PoT selection is done on the observed signal (using the `match_extremes()` function), which returns the extreme event peak values for the observed and modelled signals in decreasing order (and the time lag between events). See details below.
+  * `R1`: Difference between observed and modelled for the biggest storm
+  * `R1_norm`: Normalized R1 (R1 divided by observed value)
+  * `R3`: Average difference between observed and modelled for the three biggest storms
+  * `R3_norm`: Normalized R3 (R3 divided by observed value)
+  * `error`: Average difference between observed and modelled for all storms
+  * `error_norm`: Normalized error (error divided by observed value)
+
+## General metrics
 ### A. Dimensional Statistics:
 #### Mean Error (or Bias)
 $$\langle x_c - x_m \rangle = \langle x_c \rangle - \langle x_m \rangle$$
@@ -88,7 +81,25 @@
 with : $$\lambda = 1 - \frac{\sum{(x_c - x_m)^2}}{\sum{(x_m - \overline{x}_m)^2} + \sum{(x_c - \overline{x}_c)^2} + n(\overline{x}_m - \overline{x}_c)^2 + \kappa}$$
 * with `kappa` $$2 \cdot \left| \sum{((x_m - \overline{x}_m) \cdot (x_c - \overline{x}_c))} \right|$$
-# Extreme events
+## Storm metrics
+The function uses `match_extremes()` (detailed below) and returns:
+ * `R1`: the error for the biggest storm
+ * `R3`: the mean error for the 3 biggest storms
+ * `error`: the mean error for all the storms above the threshold.
+ * `R1_norm`/`R3_norm`/`error_norm`: same methodology, but the values are normalised (in %) relative to the observed peaks.
+
+
+### Case of NaNs
+The storm metrics might be returned as:
+```python
+{'R1': np.nan,
+ 'R1_norm': np.nan,
+ 'R3': np.nan,
+ 'R3_norm': np.nan,
+ 'error': np.nan,
+ 'error_norm': np.nan}
+```
+## Extreme events
 Example of implementation:
 
 ```python
@@ -112,38 +123,6 @@ with:
 NB: the function uses [pyextremes](https://georgebv.github.io/pyextremes/quickstart/) in the background, with PoT method, using the `quantile` value of the observed signal as physical threshold and passes the `cluster_duration` argument.
 
-# Storm metrics
-The functions uses the above `match_extremes()` and returns:
- * `R1`: the error for the biggest storm
- * `R3`: the mean error for the 3 biggest storms
- * `error`: the mean error for all the storms above the threshold.
- * `R1_norm`/`R3_norm`/`error`: Same methodology, but values are in normalised (in %) relatively to the observed peaks.
-
-Example of implementation:
-```python
-from seastats.storms import storm_metrics
-storm_metrics(sim: pd.Series, obs: pd.Series, quantile: float, cluster_duration:int = 72)
-```
-returns this dictionary:
-```python
-{'R1': 0.237,
- 'R1_norm': 0.296,
- 'R3': 0.147,
- 'R3_norm': 0.207,
- 'error': 0.0938,
- 'error_norm': 0.178}
-```
-
-### case of NaNs
-The `storm_metrics()` might return:
-```python
-{'R1': np.nan,
- 'R1_norm': np.nan,
- 'R3': np.nan,
- 'R3_norm': np.nan,
- 'error': np.nan,
- 'error_norm': np.nan}
-```
 this happens when the function `storms/match_extremes.py` couldn't find concomitant storms for the observed and modeled time series.
 
@@ -152,9 +131,10 @@
 see [notebook](/notebooks/example_abed.ipynb) for details
 
 get all metrics in a 3-liner:
 ```python
-stats = get_stats(sim, obs)
-metrics = storm_metrics(sim, obs, quantile=0.99, cluster=72)
-pd.DataFrame(dict(stats, **metrics), index=['abed'])
+from seastats import get_stats, GENERAL_METRICS, STORM_METRICS
+general = get_stats(sim, obs, metrics=GENERAL_METRICS)
+storm = get_stats(sim, obs, quantile=0.99, metrics=STORM_METRICS)  # we use a different quantile for PoT selection
+pd.DataFrame(dict(general, **storm), index=['abed'])
 ```
 
 | | bias | rmse | rms | rms_95 | sim_mean | obs_mean | sim_std | obs_std | nse | lamba | cr | cr_95 | slope | intercept | slope_pp | intercept_pp | mad | madp | madc | kge | R1 | R1_norm | R3 | R3_norm | error | error_norm |
diff --git a/seastats/__init__.py b/seastats/__init__.py
index 85c5cf1..8629f48 100644
--- a/seastats/__init__.py
+++ b/seastats/__init__.py
@@ -1,7 +1,196 @@
-from seastats.stats import GENERAL_METRICS
-from seastats.stats import GENERAL_METRICS_ALL
-from seastats.storms import STORM_METRICS
-from seastats.storms import STORM_METRICS_ALL
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+
+import numpy as np
+import pandas as pd
+
+from seastats.stats import get_bias
+from seastats.stats import get_corr
+from seastats.stats import get_kge
+from seastats.stats import get_lambda
+from seastats.stats import get_mad
+from seastats.stats import get_madc
+from seastats.stats import get_madp
+from seastats.stats import get_mae
+from seastats.stats import get_mse
+from seastats.stats import get_nse
+from seastats.stats import get_rms
+from seastats.stats import get_rmse
+from seastats.stats import get_slope_intercept
+from seastats.stats import get_slope_intercept_pp
+from seastats.storms import match_extremes
+
+logger = logging.getLogger(__name__)
+
+GENERAL_METRICS_ALL = [
+    "bias",
+    "rmse",
+    "rms",
+    "rms_qm",
+    "sim_mean",
+    "obs_mean",
+    "sim_std",
+    "obs_std",
+    "mae",
+    "mse",
+    "nse",
+    "lamba",
+    "cr",
+    "cr_qm",
+    "slope",
+    "intercept",
+    "slope_pp",
+    "intercept_pp",
+    "mad",
+    "madp",
+    "madc",
+    "kge",
+]
+GENERAL_METRICS = ["bias", "rms", "rmse", "cr", "nse", "kge"]
+STORM_METRICS = ["R1", "R3", "error"]
+STORM_METRICS_ALL = ["R1", "R1_norm", "R3", "R3_norm", "error", "error_norm"]
 
 SUGGESTED_METRICS = sorted(GENERAL_METRICS + STORM_METRICS)
 SUPPORTED_METRICS = sorted(GENERAL_METRICS_ALL + STORM_METRICS_ALL)
+
+
+def get_stats(
+    sim: pd.Series,
+    obs: pd.Series,
+    metrics: Sequence[str] = None,
+    quantile: float = 0,
+    cluster: int = 72,
+    round: int = -1,
+) -> dict[str, float]:
+    """
+    Calculates various statistical metrics between the simulated and observed time series data.
+
+    :param pd.Series sim: The simulated time series data.
+    :param pd.Series obs: The observed time series data.
+    :param list[str] metrics: (Optional) The list of statistical metrics to calculate. If None, all items in SUGGESTED_METRICS will be calculated. If ["all"], all items in SUPPORTED_METRICS will be calculated. Default is None.
+    :param float quantile: (Optional) Quantile used to calculate the metrics. Default is 0 (no selection)
+    :param int cluster: (Optional) Cluster duration for grouping storm events. Default is 72 hours.
+    :param int round: (Optional) Number of decimals to round the results to. Default is -1 (no rounding)
+
+    :return dict stats: dictionary containing the calculated metrics.
+
+    The dictionary contains the following keys and their corresponding values:
+
+    - `bias`: The bias between the simulated and observed time series data.
+    - `rmse`: The Root Mean Square Error between the simulated and observed time series data.
+    - `mae`: The Mean Absolute Error between the simulated and observed time series data.
+    - `mse`: The Mean Square Error between the simulated and observed time series data.
+    - `rms`: The root mean square error between the simulated and observed time series data.
+    - `sim_mean`: The mean of the simulated time series data.
+    - `obs_mean`: The mean of the observed time series data.
+    - `sim_std`: The standard deviation of the simulated time series data.
+    - `obs_std`: The standard deviation of the observed time series data.
+    - `nse`: The Nash-Sutcliffe efficiency between the simulated and observed time series data.
+    - `lamba`: The lambda statistic between the simulated and observed time series data.
+    - `cr`: The correlation coefficient between the simulated and observed time series data.
+    - `slope`: The slope of the linear regression between the simulated and observed time series data.
+    - `intercept`: The intercept of the linear regression between the simulated and observed time series data.
+    - `slope_pp`: The slope of the linear regression between the percentiles of the simulated and observed time series data.
+    - `intercept_pp`: The intercept of the linear regression between the percentiles of the simulated and observed time series data.
+    - `mad`: The median absolute deviation of the simulated time series data from its median.
+    - `madp`: The median absolute deviation of the simulated time series data from its median, calculated using the percentiles of the observed time series data.
+    - `madc`: The median absolute deviation of the simulated time series data from its median, calculated by adding `mad` to `madp`
+    - `kge`: The Kling-Gupta efficiency between the simulated and observed time series data.
+    - `R1`: Difference between observed and modelled for the biggest storm
+    - `R1_norm`: Normalized R1 (R1 divided by observed value)
+    - `R3`: Average difference between observed and modelled for the three biggest storms
+    - `R3_norm`: Normalized R3 (R3 divided by observed value)
+    - `error`: Average difference between observed and modelled for all storms
+    - `error_norm`: Normalized error (error divided by observed value)
+    """
+
+    if metrics is None:
+        metrics = SUGGESTED_METRICS
+    elif metrics == ["all"]:
+        metrics = SUPPORTED_METRICS
+    else:
+        if not isinstance(metrics, list):
+            raise ValueError(
+                "metrics must be a list of supported variables e.g. SUGGESTED_METRICS, or None, or ['all']"
+            )
+
+    if quantile > 0:  # signal subsetting is only to be done on general metrics
+        sim_ = sim[sim > sim.quantile(quantile)]
+        obs_ = obs[obs > obs.quantile(quantile)]
+    else:
+        sim_ = sim
+        obs_ = obs
+
+    stats = {}
+
+    for metric in metrics:
+        match metric:
+            case "bias":
+                stats["bias"] = get_bias(sim_, obs_)
+            case "rmse":
+                stats["rmse"] = get_rmse(sim_, obs_)
+            case "rms":
+                stats["rms"] = get_rms(sim_, obs_)
+            case "sim_mean":
+                stats["sim_mean"] = sim_.mean()
+            case "obs_mean":
+                stats["obs_mean"] = obs_.mean()
+            case "sim_std":
+                stats["sim_std"] = sim_.std()
+            case "obs_std":
+                stats["obs_std"] = obs_.std()
+            case "mae":
+                stats["mae"] = get_mae(sim_, obs_)
+            case "mse":
+                stats["mse"] = get_mse(sim_, obs_)
+            case "nse":
+                stats["nse"] = get_nse(sim_, obs_)
+            case "lamba":
+                stats["lamba"] = get_lambda(sim_, obs_)
+            case "cr":
+                stats["cr"] = get_corr(sim_, obs_)
+            case "slope":
+                stats["slope"], _ = get_slope_intercept(sim_, obs_)
+            case "intercept":
+                _, stats["intercept"] = get_slope_intercept(sim_, obs_)
+            case "slope_pp":
+                stats["slope_pp"], _ = get_slope_intercept_pp(sim_, obs_)
+            case "intercept_pp":
+                _, stats["intercept_pp"] = get_slope_intercept_pp(sim_, obs_)
+            case "mad":
+                stats["mad"] = get_mad(sim_, obs_)
+            case "madp":
+                stats["madp"] = get_madp(sim_, obs_)
+            case "madc":
+                stats["madc"] = get_madc(sim_, obs_)
+            case "kge":
+                stats["kge"] = get_kge(sim_, obs_)
+
+    # Storm metrics part with PoT Selection
+    if np.any([m in STORM_METRICS_ALL for m in metrics]):
+        df = match_extremes(sim, obs, quantile=quantile, cluster=cluster)
+
+        for metric in metrics:
+            match metric:
+                case "R1":
+                    stats["R1"] = df["error"].iloc[0]
+                case "R1_norm":
+                    stats["R1_norm"] = df["error_norm"].iloc[0]
+                case "R3":
+                    stats["R3"] = df["error"].iloc[0:3].mean()
+                case "R3_norm":
+                    stats["R3_norm"] = df["error_norm"].iloc[0:3].mean()
+                case "error":
+                    stats["error"] = df["error"].mean()
+                case "error_norm":
+                    stats["error_norm"] = df["error_norm"].mean()
+    else:
+        logger.info("no storm metric specified")
+
+    if round >= 0:
+        for metric in metrics:
+            stats[metric] = np.round(stats[metric], round)
+
+    return stats
diff --git a/seastats/stats.py b/seastats/stats.py
index 45d5424..f6b216f 100644
--- a/seastats/stats.py
+++ b/seastats/stats.py
@@ -1,38 +1,12 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Sequence
 
 import numpy as np
 import pandas as pd
 
 logger = logging.getLogger(__name__)
 
-GENERAL_METRICS_ALL = [
-    "bias",
-    "rmse",
-    "rms",
-    "rms_qm",
-    "sim_mean",
-    "obs_mean",
-    "sim_std",
-    "obs_std",
-    "nse",
-    "lamba",
-    "cr",
-    "cr_qm",
-    "slope",
-    "intercept",
-    "slope_pp",
-    "intercept_pp",
-    "mad",
-    "madp",
-    "madc",
-    "kge",
-]
-
-GENERAL_METRICS = ["bias", "rms", "rmse", "cr", "nse", "kge"]
-
 def get_bias(sim: pd.Series, obs: pd.Series) -> float:
     return sim.mean() - obs.mean()
@@ -170,107 +144,3 @@ def get_slope_intercept_pp(sim: pd.Series, obs: pd.Series) -> tuple[float, float
     pc1, pc2 = get_percentiles(sim, obs)
     slope, intercept = get_slope_intercept(pc1, pc2)
     return slope, intercept
-
-
-def get_stats(
-    sim: pd.Series,
-    obs: pd.Series,
-    metrics: Sequence[str] = None,
-    quantile: float = 0,
-    round: int = -1,
-) -> dict[str, float]:
-    """
-    Calculates various statistical metrics between the simulated and observed time series data.
-
-    Parameters:
-        sim (pd.Series): The simulated time series data.
-        obs (pd.Series): The observed time series data.
-        metrics (str/list, optional): The list of statistical metrics to calculate. If None, all metrics will be calculated. Default is None.
-        quantile (float, optional): Quantile used to calculate the metrics. Default is 0 (no selection)
-        round (int, optional): Apply rounding to the results to. Default is no rounding (value is -1)
-
-    Returns:
-        Dict[str, float]: A dictionary containing the calculated statistical metrics.
-
-    The dictionary contains the following keys and their corresponding values:
-
-    - `bias`: The bias between the simulated and observed time series data.
-    - `rmse`: The root mean square error between the simulated and observed time series data.
-    - `rms`: The raw mean square error between the simulated and observed time series data.
-    - `sim_mean`: The mean of the simulated time series data.
-    - `obs_mean`: The mean of the observed time series data.
-    - `sim_std`: The standard deviation of the simulated time series data.
-    - `obs_std`: The standard deviation of the observed time series data.
-    - `nse`: The Nash-Sutcliffe efficiency between the simulated and observed time series data.
-    - `lamba`: The lambda statistic between the simulated and observed time series data.
-    - `cr`: The correlation coefficient between the simulated and observed time series data.
-    - `slope`: The slope of the linear regression between the simulated and observed time series data.
-    - `intercept`: The intercept of the linear regression between the simulated and observed time series data.
-    - `slope_pp`: The slope of the linear regression between the percentiles of the simulated and observed time series data.
-    - `intercept_pp`: The intercept of the linear regression between the percentiles of the simulated and observed time series data.
-    - `mad`: The median absolute deviation of the simulated time series data from its median.
-    - `madp`: The median absolute deviation of the simulated time series data from its median, calculated using the percentiles of the observed time series data.
-    - `madc`: The median absolute deviation of the simulated time series data from its median, calculated by adding `mad` to `madp`
-    - `kge`: The Kling-Gupta efficiency between the simulated and observed time series data.
-    """
-
-    if metrics is None:
-        metrics = GENERAL_METRICS
-    elif metrics == ["all"]:
-        metrics = GENERAL_METRICS_ALL
-    else:
-        if not isinstance(metrics, list):
-            raise ValueError(
-                "metrics must be a list of supported variables e.g. GENERAL_METRICS, or None, or ['all']"
-            )
-
-    if quantile > 0:
-        sim = sim[sim > sim.quantile(quantile)]
-        sim = obs[obs > obs.quantile(quantile)]
-
-    stats = {}
-
-    for metric in metrics:
-        match metric:
-            case "bias":
-                stats["bias"] = get_bias(sim, obs)
-            case "rmse":
-                stats["rmse"] = get_rmse(sim, obs)
-            case "rms":
-                stats["rms"] = get_rms(sim, obs)
-            case "sim_mean":
-                stats["sim_mean"] = sim.mean()
-            case "obs_mean":
-                stats["obs_mean"] = obs.mean()
-            case "sim_std":
-                stats["sim_std"] = sim.std()
-            case "obs_std":
-                stats["obs_std"] = obs.std()
-            case "nse":
-                stats["nse"] = get_nse(sim, obs)
-            case "lamba":
-                stats["lamba"] = get_lambda(sim, obs)
-            case "cr":
-                stats["cr"] = get_corr(sim, obs)
-            case "slope":
-                stats["slope"], _ = get_slope_intercept(sim, obs)
-            case "intercept":
-                _, stats["intercept"] = get_slope_intercept(sim, obs)
-            case "slope_pp":
-                stats["slope_pp"], _ = get_slope_intercept_pp(sim, obs)
-            case "intercept_pp":
-                _, stats["intercept_pp"] = get_slope_intercept_pp(sim, obs)
-            case "mad":
-                stats["mad"] = get_mad(sim, obs)
-            case "madp":
-                stats["madp"] = get_madp(sim, obs)
-            case "madc":
-                stats["madc"] = get_madc(sim, obs)
-            case "kge":
-                stats["kge"] = get_kge(sim, obs)
-
-    if round > 0:
-        for metric in metrics:
-            stats[metric] = np.round(stats[metric])
-
-    return stats
diff --git a/seastats/storms.py b/seastats/storms.py
index 77bce27..5b5e7e3 100644
--- a/seastats/storms.py
+++ b/seastats/storms.py
@@ -1,15 +1,7 @@
-from __future__ import annotations
-
-from collections.abc import Sequence
-
 import numpy as np
 import pandas as pd
 from pyextremes import get_extremes
 
-STORM_METRICS = ["R1", "R3", "error"]
-
-STORM_METRICS_ALL = ["R1", "R1_norm", "R3", "R3_norm", "error" "error_norm"]
-
 def match_extremes(
     sim: pd.Series, obs: pd.Series, quantile: float, cluster: int = 72
@@ -73,59 +65,3 @@ def match_extremes(
     df["tdiff"] = df["time model"] - df["time observed"]
     df["tdiff"] = df["tdiff"].apply(lambda x: x.total_seconds() / 3600)
     return df
-
-
-def storm_metrics(
-    sim: pd.Series,
-    obs: pd.Series,
-    quantile: float,
-    cluster: int = 72,
-    metrics: Sequence[str] = None,
-) -> dict[str, float]:
-    """
-    Calculate metrics for comparing simulated and observed storm events
-    Parameters:
-    - sim (pd.Series): Simulated storm series
-    - obs (pd.Series): Observed storm series
-    - quantile (float): Quantile value for defining extreme events
-    - cluster (int, optional): Cluster duration for grouping storm events. Default is 72 hours
-
-    Returns:
-    - Dict[str, float]: Dictionary containing calculated metrics:
-    - R1: Difference between observed and modelled for the biggest storm
-    - R1_norm: Normalized R1 (R1 divided by observed value)
-    - R3: Average difference between observed and modelled for the three biggest storms
-    - R3_norm: Normalized R3 (R3 divided by observed value)
-    - error: Average difference between observed and modelled for all storms
-    - error_norm: Normalized error (error divided by observed value)
-    """
-    df = match_extremes(sim, obs, quantile=quantile, cluster=cluster)
-
-    if metrics is None:
-        metrics = STORM_METRICS
-    elif metrics == ["all"]:
-        metrics = STORM_METRICS_ALL
-    else:
-        if not isinstance(metrics, list):
-            raise ValueError(
-                "metrics must be a list of supported variables e.g. STORM_METRICS, or None, or ['all']"
-            )
-
-    storm_stats = {}
-
-    for metric in STORM_METRICS:
-        match metric:
-            case "R1":
-                storm_stats["R1"] = df["error"].iloc[0]
-            case "R1_norm":
-                storm_stats["R1_norm"] = df["error_norm"].iloc[0]
-            case "R3":
-                storm_stats["R3"] = df["error"].iloc[0:3].mean()
-            case "R3_norm":
-                storm_stats["R3_norm"] = df["error_norm"].iloc[0:3].mean()
-            case "error":
-                storm_stats["error"] = df["error"].mean()
-            case "error_norm":
-                storm_stats["error_norm"] = df["error_norm"].mean()
-
-    return storm_stats
diff --git a/tests/compute_stats_test.py b/tests/compute_stats_test.py
index f1d6b00..e6737c2 100644
--- a/tests/compute_stats_test.py
+++ b/tests/compute_stats_test.py
@@ -1,11 +1,7 @@
 import pandas as pd
 import pytest
 
-from seastats import GENERAL_METRICS
-from seastats import GENERAL_METRICS_ALL
-from seastats import STORM_METRICS_ALL
-from seastats.stats import get_stats
-from seastats.storms import storm_metrics
+from seastats import get_stats
 
 SIM = pd.read_parquet("tests/data/abed_sim.parquet")
 SIM = SIM[SIM.columns[0]]
@@ -34,20 +30,19 @@
 
 # Define test cases
 test_cases = [
-    ("all_stats_and_extremes", 0.99, RESULTS),
-    ("all_stats", None, {m: RESULTS[m] for m in GENERAL_METRICS}),
+    (
+        "general_stats",
+        {"quantile": 0, "metrics": ["rmse", "rms", "nse", "lamba", "cr", "slope"]},
+    ),
+    ("storm_stats", {"quantile": 0.99, "metrics": ["R1", "R3", "error"]}),
 ]
 
 
-@pytest.mark.parametrize("test_type, quantile, expected", test_cases)
-def test_metrics(test_type, quantile, expected):
-    if test_type == "all_stats_and_extremes":
-        stats = get_stats(SIM, OBS, metrics=GENERAL_METRICS_ALL)
-        sm = storm_metrics(SIM, OBS, quantile=quantile, metrics=STORM_METRICS_ALL)
-        stats = dict(stats, **sm)
-    elif test_type == "all_stats":
-        stats = get_stats(SIM, OBS)
+@pytest.mark.parametrize("test_type, args", test_cases)
+def test_metrics(test_type, args):
+    stats = get_stats(SIM, OBS, **args)
 
     # Assert all metrics
-    for metric, value in expected.items():
+    check_dict = {m: RESULTS[m] for m in args["metrics"]}
+    for metric, value in check_dict.items():
         assert stats[metric] == pytest.approx(value, abs=1.0e-3)
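
For reviewers, a minimal usage sketch of the streamlined API this patch introduces. The observed-data path (`tests/data/abed_obs.parquet`) and the `abed` label are assumptions for illustration — only the sim path appears in the test hunk above; everything else follows the patched `get_stats` signature and the README 3-liner.

```python
# Usage sketch: general vs. storm metrics with the single get_stats() entry point.
import pandas as pd

from seastats import GENERAL_METRICS, STORM_METRICS, get_stats

# Load the surge series as pd.Series (single-column parquet -> Series, as in the tests).
sim = pd.read_parquet("tests/data/abed_sim.parquet")
sim = sim[sim.columns[0]]
obs = pd.read_parquet("tests/data/abed_obs.parquet")  # assumed path
obs = obs[obs.columns[0]]

# General metrics on the full signals (quantile=0 is the default, i.e. no subsetting),
# rounded to 3 decimals via the `round` argument.
general = get_stats(sim, obs, metrics=GENERAL_METRICS, round=3)

# Storm metrics: PoT selection on the observed signal above its 0.99 quantile,
# grouping peaks closer than 72 h into one event (handled by match_extremes()).
storm = get_stats(sim, obs, metrics=STORM_METRICS, quantile=0.99, cluster=72, round=3)

# One summary row per station, mirroring the README example.
print(pd.DataFrame(dict(general, **storm), index=["abed"]))
```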