address @pmav99 reviews
tomsail committed Nov 8, 2024
1 parent 19f803e commit b3b8f36
Showing 2 changed files with 90 additions and 52 deletions.
140 changes: 89 additions & 51 deletions seastats/stats.py
@@ -1,12 +1,35 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd

_ALL_METRICS = [
"bias",
"rmse",
"rms",
"rms_95",
"sim_mean",
"obs_mean",
"sim_std",
"obs_std",
"nse",
"lamba",
"cr",
"cr_95",
"slope",
"intercept",
"slope_pp",
"intercept_pp",
"mad",
"madp",
"madc",
"kge",
]


def get_bias(sim: pd.Series, obs: pd.Series, round: int = 3) -> float:
bias = sim.mean() - obs.mean()
@@ -179,10 +202,10 @@ def get_slope_intercept_pp(


def get_stats(
sim: pd.Series, # The simulated time series data.
obs: pd.Series, # The observed time series data.
metrics: Optional[Union[str, List[str]]] = None,
round: int = 3, # The number of decimal places to round the results to. Default is 3.
sim: pd.Series,
obs: pd.Series,
metrics: Sequence[str] = ["all"],
round: int = 3,
) -> Dict[str, float]:
"""
Calculates various statistical metrics between the simulated and observed time series data.
@@ -220,52 +243,67 @@
- `kge`: The Kling-Gupta efficiency between the simulated and observed time series data.
"""

def should_calculate(metric_name):
return metrics is None or metric_name in metrics_list
if metrics == ["all"]:
metrics = _ALL_METRICS
print(metrics)

metrics_list = [metrics] if isinstance(metrics, str) else metrics
version_stat = {}

# Calculate each metric if it should be calculated
if should_calculate("bias"):
version_stat["bias"] = get_bias(sim, obs, round)
if should_calculate("rmse"):
version_stat["rmse"] = get_rmse(sim, obs, round)
if should_calculate("rms"):
version_stat["rms"] = get_rms(sim, obs, round)
if should_calculate("rms_95"):
version_stat["rms_95"] = get_rms_quantile(sim, obs, 0.95, round)
if should_calculate("sim_mean"):
version_stat["sim_mean"] = np.round(sim.mean(), round)
if should_calculate("obs_mean"):
version_stat["obs_mean"] = np.round(obs.mean(), round)
if should_calculate("sim_std"):
version_stat["sim_std"] = np.round(sim.std(), round)
if should_calculate("obs_std"):
version_stat["obs_std"] = np.round(obs.std(), round)
if should_calculate("nse"):
version_stat["nse"] = get_nse(sim, obs, round)
if should_calculate("lamba"):
version_stat["lamba"] = get_lambda(sim, obs, round)
if should_calculate("cr"):
version_stat["cr"] = get_corr(sim, obs, round)
if should_calculate("cr_95"):
version_stat["cr_95"] = get_corr_quantile(sim, obs, 0.95, round)
if should_calculate("slope") or should_calculate("intercept"):
slope, intercept = get_slope_intercept(sim, obs)
version_stat["slope"] = slope
version_stat["intercept"] = intercept
if should_calculate("slope_pp") or should_calculate("intercept_pp"):
slopepp, interceptpp = get_slope_intercept_pp(sim, obs)
version_stat["slope_pp"] = slopepp
version_stat["intercept_pp"] = interceptpp
if should_calculate("mad"):
version_stat["mad"] = get_mad(sim, obs, round)
if should_calculate("madp"):
version_stat["madp"] = get_madp(sim, obs, round)
if should_calculate("madc"):
version_stat["madc"] = get_madc(sim, obs, round)
if should_calculate("kge"):
version_stat["kge"] = get_kge(sim, obs, round)
for metric in metrics:
match metric:
case "bias":
version_stat["bias"] = get_bias(sim, obs, round)
case "rmse":
version_stat["rmse"] = get_rmse(sim, obs, round)
case "rms":
version_stat["rms"] = get_rms(sim, obs, round)
case "rms_95":
version_stat["rms_95"] = get_rms_quantile(sim, obs, 0.95, round)
case "sim_mean":
version_stat["sim_mean"] = np.round(sim.mean(), round)
case "obs_mean":
version_stat["obs_mean"] = np.round(obs.mean(), round)
case "sim_std":
version_stat["sim_std"] = np.round(sim.std(), round)
case "obs_std":
version_stat["obs_std"] = np.round(obs.std(), round)
case "nse":
version_stat["nse"] = get_nse(sim, obs, round)
case "lamba":
version_stat["lamba"] = get_lambda(sim, obs, round)
case "cr":
version_stat["cr"] = get_corr(sim, obs, round)
case "cr_95":
version_stat["cr_95"] = get_corr_quantile(sim, obs, 0.95, round)
case "slope":
slope, intercept = get_slope_intercept(sim, obs, round)
version_stat["slope"] = slope
version_stat["intercept"] = intercept
case "intercept":
if "slope" in metrics:
pass # because already computed
else:
slope, intercept = get_slope_intercept(sim, obs, round)
version_stat["slope"] = slope
version_stat["intercept"] = intercept
case "slope_pp":
slopepp, interceptpp = get_slope_intercept_pp(sim, obs, round)
version_stat["slope_pp"] = slopepp
version_stat["intercept_pp"] = interceptpp
case "intercept_pp":
if "slope_pp" in metrics:
pass # because already computed
else:
slopepp, interceptpp = get_slope_intercept_pp(sim, obs, round)
version_stat["slope_pp"] = slopepp
version_stat["intercept_pp"] = interceptpp
case "mad":
version_stat["mad"] = get_mad(sim, obs, round)
case "madp":
version_stat["madp"] = get_madp(sim, obs, round)
case "madc":
version_stat["madc"] = get_madc(sim, obs, round)
case "kge":
version_stat["kge"] = get_kge(sim, obs, round)

return version_stat
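
As a rough usage sketch of the reworked `get_stats` signature (hypothetical `sim`/`obs` series; assumes the package is importable as `seastats.stats`):

import numpy as np
import pandas as pd

from seastats.stats import get_stats

# Hypothetical aligned series: obs is noise, sim is obs shifted by a constant.
idx = pd.date_range("2024-01-01", periods=100, freq="h")
obs = pd.Series(np.random.default_rng(0).normal(size=100), index=idx)
sim = obs + 0.05

all_stats = get_stats(sim, obs)                         # metrics=["all"] expands to _ALL_METRICS
subset = get_stats(sim, obs, metrics=["bias", "rmse"])  # {'bias': 0.05, 'rmse': 0.05}

# Requesting only "intercept" still fills both keys, because the dispatch
# computes slope and intercept together.
both = get_stats(sim, obs, metrics=["intercept"])
assert set(both) == {"slope", "intercept"}
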
2 changes: 1 addition & 1 deletion tests/compute_stats_test.py
@@ -79,7 +79,7 @@ def test_get_rmse():
# sim and obs need to be Series
obs = obs[obs.columns[0]]
sim = sim[sim.columns[0]]
stats = get_stats(sim, obs, "rmse")
stats = get_stats(sim, obs, ["rmse"])
assert stats["rmse"] == 0.086


