✨ update error extraction (seaborn 0.12.2 and 0.13.2) (#80)
* ✨ update error extraction (seaborn 0.12.2 and 0.13.2)

- check that everything works with 0.12.2 before testing the switch to 0.13.2

* ⬆️ remove seaborn upper limit

- next will be to remove matplotlib's upper limit

* ⬆️ remove upper matplotlib limit

- after updating to seaborn 0.13 this should be safe

* 🐛 test the tutorial on the PR branch

- PRs should test the branch version
- scheduled runs should test that the PyPI version works
- development branch was deleted (Trunk Based Development)
enryH authored Sep 2, 2024
1 parent cfbb8e6 commit 6f391c0
Showing 8 changed files with 47 additions and 59 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/test_pkg_on_colab.yaml
@@ -17,7 +17,12 @@ jobs:
      image: europe-docker.pkg.dev/colab-images/public/runtime:latest
    steps:
      - uses: actions/checkout@v4
      - name: Install pimms-learn and papermill
      - name: Install pimms-learn (from branch) and papermill
        if: github.event_name == 'pull_request'
        run: |
          python3 -m pip install pimms-learn papermill
      - name: Install pimms-learn (from PyPI) and papermill
        if: github.event_name == 'schedule'
        run: |
          python3 -m pip install pimms-learn papermill
      - name: Run tutorial
4 changes: 2 additions & 2 deletions environment.yml
@@ -14,10 +14,10 @@ dependencies:
- pandas>=1
- scipy>=1.6
# plotting
- matplotlib>=3.4,<3.9
- matplotlib>=3.4
- python-kaleido
- plotly
- seaborn<0.13
- seaborn
- pip
# ML
- pytorch #=1.13.1=py3.8_cuda11.7_cudnn8_0
76 changes: 29 additions & 47 deletions pimmslearn/plotting/errors.py
@@ -1,15 +1,13 @@
"""Plot errors based on DataFrame with model predictions."""
from __future__ import annotations

import itertools
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from seaborn.categorical import _BarPlotter
from seaborn.categorical import EstimateAggregator


import pimmslearn.pandas.calc_errors
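As context for the import swap above: _BarPlotter is a seaborn-internal class that no longer exists in 0.13, while EstimateAggregator (defined in seaborn._statistics and, in 0.13, re-exported through seaborn.categorical) is the helper the new code builds on. A small illustrative check, assuming seaborn 0.12.2 or 0.13.2 is installed locally; this snippet is not part of the commit:

# Illustrative only: check which seaborn internals are importable locally.
import seaborn
print(seaborn.__version__)

try:
    # internal plotter class used by the pre-0.13 implementation further down
    from seaborn.categorical import _BarPlotter  # noqa: F401
    print("seaborn < 0.13: _BarPlotter is available")
except ImportError:
    print("seaborn >= 0.13: _BarPlotter was removed")

# EstimateAggregator exists in both 0.12 and 0.13 under seaborn._statistics.
from seaborn._statistics import EstimateAggregator  # noqa: F401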

@@ -109,52 +107,36 @@ def plot_errors_by_median(pred: pd.DataFrame,
    return ax, errors


def get_data_for_errors_by_median(errors: pd.DataFrame, feat_name, metric_name, seed=None):
    """Extract Bars with confidence intervals from seaborn plot.
    Confident intervals are calculated with bootstrapping (sampling the mean).
    Relies on internal seaborn class. only used for reporting of source data in the paper.
def get_data_for_errors_by_median(errors: pd.DataFrame,
                                  feat_name: str,
                                  metric_name: str,
                                  model_column: str = 'model',
                                  seed: int = 42) -> pd.DataFrame:
    """Extract bars with confidence intervals from a seaborn bar plot (seaborn 0.13 and above).
    Confidence intervals are calculated with bootstrapping (sampling the mean).

    Parameters
    ----------
    errors: pd.DataFrame
        DataFrame created by `plot_errors_by_median` function
    feat_name: str
        feature name assigned (was transformed to 'intensity binned by median of {feat_name}')
    metric_name: str
        Metric used to calculate errors (MAE, MSE, etc.) of intensities in a bin
    model_column: str
        column in errors defining the model names
    """
    x_axis_name = f'intensity binned by median of {feat_name}'
    aggregator = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=seed)
    # ! need to iterate over all models myself using groupby
    ret = (errors
           .groupby(by=[x_axis_name, model_column], observed=True)
           [[x_axis_name, model_column, metric_name]]
           .apply(lambda df: aggregator(df, metric_name))
           .reset_index())
    ret.columns = ["bin", model_column, "mean", "ci_low", "ci_high"]
    return ret

    plotter = _BarPlotter(data=errors, x=x_axis_name, y=metric_name, hue='model',
                          order=None, hue_order=None,
                          estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=seed,
                          orient=None, color=None, palette=None, saturation=.75, width=.8,
                          errcolor=".26", errwidth=None, capsize=None, dodge=True)
    ax = plt.gca()
    plotter.plot(ax, {})
    plt.close(ax.get_figure())
    mean, cf_interval = plotter.statistic.flatten(), plotter.confint.reshape(-1, 2)
    plotted = pd.DataFrame(np.concatenate((mean.reshape(-1, 1), cf_interval), axis=1), columns=[
        'mean', 'ci_low', 'ci_high'])
    _index = pd.DataFrame(list(itertools.product(
        (_l.get_text() for _l in ax.get_xticklabels()),  # bins x-axis
        (_l.get_text() for _l in ax.get_legend().get_texts()),  # models legend
    )
    ), columns=['bin', 'model'])
    plotted = pd.concat([_index, plotted], axis=1)
    return plotted


# def get_data_for_errors_by_median_v2(errors: pd.DataFrame, feat_name, metric_name):
# from seaborn._statistics import (
# EstimateAggregator,
# WeightedAggregator,
# )
# from seaborn.categorical import _CategoricalAggPlotter, WeightedAggregator, EstimateAggregator
# p = _CategoricalAggPlotter(
# data=data,
# variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
# order=order,
# orient=orient,
# color=color,
# legend=legend,
# )

# agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
# aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
# err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D)


def plot_rolling_error(errors: pd.DataFrame, metric_name: str, window: int = 200,
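To make the new extraction concrete, here is a small self-contained sketch of the same groupby plus EstimateAggregator pattern used in get_data_for_errors_by_median above, run on made-up toy data; the bin labels, model names and metric column are illustrative and not taken from the repository:

# Illustrative toy data for the seaborn-0.13-style extraction shown above.
import numpy as np
import pandas as pd
from seaborn._statistics import EstimateAggregator

rng = np.random.default_rng(42)
errors = pd.DataFrame({
    "bin": np.repeat(["low", "mid", "high"], 50),
    "model": np.tile(["CF", "DAE"], 75),
    "MAE": rng.gamma(shape=2.0, scale=0.5, size=150),
})

# Mean with a bootstrapped 95% CI per (bin, model) group.
aggregator = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=42)
summary = (errors
           .groupby(["bin", "model"], observed=True)[["MAE"]]
           .apply(lambda df: aggregator(df, "MAE"))
           .reset_index())
summary.columns = ["bin", "model", "mean", "ci_low", "ci_high"]
print(summary)

Each (bin, model) pair ends up as one row with the bootstrapped mean and interval bounds, i.e. the same source-data table that the removed _BarPlotter-based code reconstructed from a rendered bar chart.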
4 changes: 2 additions & 2 deletions project/04_1_train_pimms_models.ipynb
@@ -34,7 +34,7 @@
" print(f\"Running in colab and pimms-learn ({_v}) is installed.\")\n",
" except metadata.PackageNotFoundError:\n",
" print(\"Install PIMMS...\")\n",
" # !pip install git+https://github.com/RasmussenLab/pimms.git@dev\n",
" # !pip install git+https://github.com/RasmussenLab/pimms.git\n",
" !pip install pimms-learn"
]
},
@@ -364,7 +364,7 @@
"metadata": {},
"outputs": [],
"source": [
"CollaborativeFilteringTransformer?"
"# # CollaborativeFilteringTransformer?"
]
},
{
4 changes: 2 additions & 2 deletions project/04_1_train_pimms_models.py
@@ -21,7 +21,7 @@
print(f"Running in colab and pimms-learn ({_v}) is installed.")
except metadata.PackageNotFoundError:
print("Install PIMMS...")
# # !pip install git+https://github.com/RasmussenLab/pimms.git@dev
# # !pip install git+https://github.com/RasmussenLab/pimms.git
# !pip install pimms-learn

# %% [markdown]
@@ -167,7 +167,7 @@
# Inspect annotations of the scikit-learn like Transformer:

# %%
# # CollaborativeFilteringTransformer?
# # # CollaborativeFilteringTransformer?

# %% [markdown]
# Let's set up collaborative filtering without a validation or test set, using
4 changes: 2 additions & 2 deletions project/workflow/envs/pimms.yaml
@@ -14,10 +14,10 @@ dependencies:
- pandas>=1
- scipy>=1.6
# plotting
- matplotlib<3.9
- matplotlib
- python-kaleido
- plotly
- seaborn<0.13
- seaborn
- pip
# ML
- pytorch #=1.13.1=py3.8_cuda11.7_cudnn8_0
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -18,13 +18,13 @@ classifiers = [
dependencies = [
"njab>=0.0.8",
"numpy",
"matplotlib<3.9",
"matplotlib",
"pandas",
"plotly",
"torch",
"scikit-learn>=1.0",
"scipy",
"seaborn<0.13",
"seaborn",
"fastai",
"omegaconf",
"tqdm",
3 changes: 2 additions & 1 deletion tests/plotting/test_errors.py
@@ -48,7 +48,8 @@ def expected_plotted():
    plotted_path = file_dir / 'expected_plotted.csv'
    # ! Windows reads newline characters in strings as '\r\n'
    df = pd.read_csv(plotted_path, sep=',', index_col=0)
    df["bin"] = df["bin"].str.replace('\r\n', '\n')
    df["bin"] = df["bin"].str.replace('\r\n', '\n').astype('category')
    df = df.sort_values(by=['bin', 'model']).reset_index(drop=True)
    return df


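The fixture change above lines up the expected CSV with the refactored extraction, which appears to return the bin labels as a categorical column and in a different row order than the old plot-derived output, hence the cast to 'category' and the sort. A minimal sketch of that order-insensitive comparison pattern on hypothetical toy frames (not the repo's fixtures):

# Illustrative only: compare two frames regardless of row order and bin dtype.
import pandas as pd

expected = pd.DataFrame({"bin": ["a", "b"], "model": ["CF", "DAE"], "mean": [1.0, 2.0]})
actual = expected.iloc[::-1].reset_index(drop=True)   # same rows, different order
actual["bin"] = actual["bin"].astype("category")      # as the groupby output presumably returns it

expected["bin"] = expected["bin"].astype("category")  # align dtype, as the fixture now does
key = ["bin", "model"]
pd.testing.assert_frame_equal(
    expected.sort_values(key).reset_index(drop=True),  # sort both sides, as the fixture does
    actual.sort_values(key).reset_index(drop=True),
)
print("frames match after aligning dtype and row order")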
