-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: streamlined API, tests and README file
- Loading branch information
Showing
5 changed files
with
278 additions
and
308 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,196 @@ | ||
from seastats.stats import GENERAL_METRICS | ||
from seastats.stats import GENERAL_METRICS_ALL | ||
from seastats.storms import STORM_METRICS | ||
from seastats.storms import STORM_METRICS_ALL | ||
from __future__ import annotations | ||
|
||
import logging | ||
from collections.abc import Sequence | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from seastats.stats import get_bias | ||
from seastats.stats import get_corr | ||
from seastats.stats import get_kge | ||
from seastats.stats import get_lambda | ||
from seastats.stats import get_mad | ||
from seastats.stats import get_madc | ||
from seastats.stats import get_madp | ||
from seastats.stats import get_mae | ||
from seastats.stats import get_mse | ||
from seastats.stats import get_nse | ||
from seastats.stats import get_rms | ||
from seastats.stats import get_rmse | ||
from seastats.stats import get_slope_intercept | ||
from seastats.stats import get_slope_intercept_pp | ||
from seastats.storms import match_extremes | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Names of every general (whole-signal) metric known to the package.
# NOTE: "lamba" (sic) is the spelling used as the result key by get_stats, so
# it is kept as-is here for backward compatibility.
# NOTE(review): "rms_qm" and "cr_qm" are listed here but get_stats has no
# branch that computes them -- confirm whether they should be implemented.
GENERAL_METRICS_ALL = [
    "bias",
    "rmse",
    "rms",
    "rms_qm",
    "sim_mean",
    "obs_mean",
    "sim_std",
    "obs_std",
    "mae",
    "mse",
    "nse",
    "lamba",
    "cr",
    "cr_qm",
    "slope",
    "intercept",
    "slope_pp",
    "intercept_pp",
    "mad",
    "madp",
    "madc",
    "kge",
]

# Default subsets; together they form SUGGESTED_METRICS.
GENERAL_METRICS = ["bias", "rms", "rmse", "cr", "nse", "kge"]
STORM_METRICS = ["R1", "R3", "error"]

# Names of every storm (extreme-event) metric.
# The comma between "error" and "error_norm" matters: without it Python
# concatenates the two adjacent literals into the single bogus entry
# "errorerror_norm", and neither real name is recognised.
STORM_METRICS_ALL = ["R1", "R1_norm", "R3", "R3_norm", "error", "error_norm"]

# Sorted unions used as the default and the "all" selections respectively.
SUGGESTED_METRICS = sorted(GENERAL_METRICS + STORM_METRICS)
SUPPORTED_METRICS = sorted(GENERAL_METRICS_ALL + STORM_METRICS_ALL)
|
||
|
||
def get_stats( | ||
sim: pd.Series, | ||
obs: pd.Series, | ||
metrics: Sequence[str] = None, | ||
quantile: float = 0, | ||
cluster: int = 72, | ||
round: int = -1, | ||
) -> dict[str, float]: | ||
""" | ||
Calculates various statistical metrics between the simulated and observed time series data. | ||
:param pd.Series sim: The simulated time series data. | ||
:param pd.Series obs: The observed time series data. | ||
:param list[str] metrics: (Optional) The list of statistical metrics to calculate. If None, all items in SUGGESTED_METRICS will be calculated. If ["all"], all items in SUPPORTED_METRICS will be calculated. Default is None. | ||
:param float quantile: (Optional) Quantile used to calculate the metrics. Default is 0 (no selection) | ||
:param int cluster: (Optional) Cluster duration for grouping storm events. Default is 72 hours. | ||
:param int round: (Optional) Apply rounding to the results to. Default is no rounding (value is -1) | ||
:return dict stats: dictionary containing the calculated metrics. | ||
The dictionary contains the following keys and their corresponding values: | ||
- `bias`: The bias between the simulated and observed time series data. | ||
- `rmse`: The Root Mean Square Error between the simulated and observed time series data. | ||
- `mae`: The Mean Absolute Error the simulated and observed time series data. | ||
- `mse`: The Mean Square Error the simulated and observed time series data. | ||
- `rms`: The raw mean square error between the simulated and observed time series data. | ||
- `sim_mean`: The mean of the simulated time series data. | ||
- `obs_mean`: The mean of the observed time series data. | ||
- `sim_std`: The standard deviation of the simulated time series data. | ||
- `obs_std`: The standard deviation of the observed time series data. | ||
- `nse`: The Nash-Sutcliffe efficiency between the simulated and observed time series data. | ||
- `lamba`: The lambda statistic between the simulated and observed time series data. | ||
- `cr`: The correlation coefficient between the simulated and observed time series data. | ||
- `slope`: The slope of the linear regression between the simulated and observed time series data. | ||
- `intercept`: The intercept of the linear regression between the simulated and observed time series data. | ||
- `slope_pp`: The slope of the linear regression between the percentiles of the simulated and observed time series data. | ||
- `intercept_pp`: The intercept of the linear regression between the percentiles of the simulated and observed time series data. | ||
- `mad`: The median absolute deviation of the simulated time series data from its median. | ||
- `madp`: The median absolute deviation of the simulated time series data from its median, calculated using the percentiles of the observed time series data. | ||
- `madc`: The median absolute deviation of the simulated time series data from its median, calculated by adding `mad` to `madp` | ||
- `kge`: The Kling-Gupta efficiency between the simulated and observed time series data. | ||
- `R1`: Difference between observed and modelled for the biggest storm | ||
- `R1_norm`: Normalized R1 (R1 divided by observed value) | ||
- `R3`: Average difference between observed and modelled for the three biggest storms | ||
- `R3_norm`: Normalized R3 (R3 divided by observed value) | ||
- `error`: Average difference between observed and modelled for all storms | ||
- `error_norm`: Normalized error (error divided by observed value) | ||
""" | ||
|
||
if metrics is None: | ||
metrics = SUGGESTED_METRICS | ||
elif metrics == ["all"]: | ||
metrics = SUPPORTED_METRICS | ||
else: | ||
if not isinstance(metrics, list): | ||
raise ValueError( | ||
"metrics must be a list of supported variables e.g. SUGGESTED_METRICS, or None, or ['all']" | ||
) | ||
|
||
if quantile > 0: # signal subsetting is only to be done on general metrics | ||
sim_ = sim[sim > sim.quantile(quantile)] | ||
obs_ = obs[obs > obs.quantile(quantile)] | ||
else: | ||
sim_ = sim | ||
obs_ = obs | ||
|
||
stats = {} | ||
|
||
for metric in metrics: | ||
match metric: | ||
case "bias": | ||
stats["bias"] = get_bias(sim_, obs_) | ||
case "rmse": | ||
stats["rmse"] = get_rmse(sim_, obs_) | ||
case "rms": | ||
stats["rms"] = get_rms(sim_, obs_) | ||
case "sim_mean": | ||
stats["sim_mean"] = sim_.mean() | ||
case "obs_mean": | ||
stats["obs_mean"] = obs_.mean() | ||
case "sim_std": | ||
stats["sim_std"] = sim_.std() | ||
case "obs_std": | ||
stats["obs_std"] = obs_.std() | ||
case "mae": | ||
stats["mae"] = get_mae(sim_, obs_) | ||
case "mse": | ||
stats["mse"] = get_mse(sim_, obs_) | ||
case "nse": | ||
stats["nse"] = get_nse(sim_, obs_) | ||
case "lamba": | ||
stats["lamba"] = get_lambda(sim_, obs_) | ||
case "cr": | ||
stats["cr"] = get_corr(sim_, obs_) | ||
case "slope": | ||
stats["slope"], _ = get_slope_intercept(sim_, obs_) | ||
case "intercept": | ||
_, stats["intercept"] = get_slope_intercept(sim_, obs_) | ||
case "slope_pp": | ||
stats["slope_pp"], _ = get_slope_intercept_pp(sim_, obs_) | ||
case "intercept_pp": | ||
_, stats["intercept_pp"] = get_slope_intercept_pp(sim_, obs_) | ||
case "mad": | ||
stats["mad"] = get_mad(sim_, obs_) | ||
case "madp": | ||
stats["madp"] = get_madp(sim_, obs_) | ||
case "madc": | ||
stats["madc"] = get_madc(sim_, obs_) | ||
case "kge": | ||
stats["kge"] = get_kge(sim_, obs_) | ||
|
||
# Storm metrics part with PoT Selection | ||
if np.any([m in STORM_METRICS_ALL for m in metrics]): | ||
df = match_extremes(sim, obs, quantile=quantile, cluster=cluster) | ||
|
||
for metric in metrics: | ||
match metric: | ||
case "R1": | ||
stats["R1"] = df["error"].iloc[0] | ||
case "R1_norm": | ||
stats["R1_norm"] = df["error_norm"].iloc[0] | ||
case "R3": | ||
stats["R3"] = df["error"].iloc[0:3].mean() | ||
case "R3_norm": | ||
stats["R3_norm"] = df["error_norm"].iloc[0:3].mean() | ||
case "error": | ||
stats["error"] = df["error"].mean() | ||
case "error_norm": | ||
stats["error_norm"] = df["error_norm"].mean() | ||
else: | ||
logger.info("no storm metric specified") | ||
|
||
if round > 0: | ||
for metric in metrics: | ||
stats[metric] = np.round(stats[metric]) | ||
|
||
return stats |
Oops, something went wrong.