From 302725962c7c74b914e4544580ebfe4f8eea0e16 Mon Sep 17 00:00:00 2001
From: DanSava
Date: Mon, 2 Dec 2024 17:32:22 +0200
Subject: [PATCH] Replace usage of load_all_summary_data() in everest data api

---
Review notes (ignored by git am; not part of the commit message):
* Line breaks and indentation of this patch were restored from the hunk
  line counts and Python syntax. The position of blank context lines and
  the absolute indentation of the export.py and test hunks are inferred;
  verify with `git apply --check` against the parent tree.
* summary_values() now returns a polars DataFrame (pl.concat) instead of
  a pandas one; the snapshot test drops polars.from_pandas() accordingly.
* realization_map is now keyed by sim.simulation without str(), while the
  lookup still calls realization_map.get(str(sim)).
  NOTE(review): if sim.simulation is not a str, every lookup returns None
  and the "realization" column becomes all-null; confirm the type.
* When `keys` is given, the selected columns come from a set
  intersection, so their order follows set iteration order.

 src/everest/api/everest_data_api.py | 51 +++++++++++++++--------------
 src/everest/export.py               | 15 ++++++++-
 tests/everest/test_api_snapshots.py |  2 +-
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/src/everest/api/everest_data_api.py b/src/everest/api/everest_data_api.py
index b56009c5467..d3b3d804fab 100644
--- a/src/everest/api/everest_data_api.py
+++ b/src/everest/api/everest_data_api.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 
-import pandas as pd
+import polars as pl
 from seba_sqlite.snapshot import SebaSnapshot
 
 from ert.storage import open_storage
@@ -154,48 +154,49 @@ def gradient_values(self):
     def summary_values(self, batches=None, keys=None):
         if batches is None:
             batches = self.batches
-        simulations = self.simulations
         data_frames = []
         storage = open_storage(self._config.storage_dir, "r")
         experiment = next(storage.experiments)
         for batch_id in batches:
             ensemble = experiment.get_ensemble_by_name(f"batch_{batch_id}")
-            summary = ensemble.load_all_summary_data()
-            if not summary.empty:
-                columns = set(summary.columns)
-                if keys is not None:
-                    columns = columns.intersection(set(keys))
-                summary = summary[list(columns)]
-                summary = summary.dropna(axis=0, how="all", subset=columns)
-                summary = summary.dropna(axis=1, how="all")
-                summary = summary[
-                    summary.index.get_level_values("Realization").isin(simulations)
-                ]
-                summary.reset_index(inplace=True)
-                summary["batch"] = batch_id
+            try:
+                summary = ensemble.load_responses(
+                    key="summary",
+                    realizations=tuple(self.simulations),
+                )
+            except (ValueError, KeyError):
+                summary = pl.DataFrame()
+
+            if not summary.is_empty():
+                summary = summary.pivot(
+                    on="response_key", index=["realization", "time"], sort_columns=True
+                )
                 # The 'Realization' column exported by ert are
                 # the 'simulations' of everest.
-                summary.rename(
-                    columns={"Realization": "simulation", "Date": "date"}, inplace=True
+                summary = summary.rename({"time": "date", "realization": "simulation"})
+                if keys is not None:
+                    columns = set(summary.columns).intersection(set(keys))
+                    summary = summary[["date", "simulation", *list(columns)]]
+                summary = summary.with_columns(
+                    pl.Series("batch", [batch_id] * summary.shape[0])
                 )
                 # The realization ID as defined by Everest must be
                 # retrieved via the seba snapshot.
                 realization_map = {
-                    str(sim.simulation): sim.realization
+                    sim.simulation: sim.realization
                     for sim in self._snapshot.simulation_data
                     if sim.batch == batch_id
                 }
-                summary["realization"] = (
-                    summary["simulation"].astype(str).map(realization_map)
-                )
-                # If possible, convert the realization id to integer.
-                summary["realization"] = pd.to_numeric(
-                    summary["realization"], errors="ignore", downcast="integer"
+                realizations = pl.Series(
+                    "realization",
+                    [realization_map.get(str(sim)) for sim in summary["simulation"]],
                 )
+                realizations = realizations.cast(pl.Int64, strict=False)
+                summary = summary.with_columns(realizations)
 
                 data_frames.append(summary)
         storage.close()
-        return pd.concat(data_frames)
+        return pl.concat(data_frames)
 
     @property
     def output_folder(self):
diff --git a/src/everest/export.py b/src/everest/export.py
index af729ace764..5f10771a307 100644
--- a/src/everest/export.py
+++ b/src/everest/export.py
@@ -336,7 +336,20 @@ def load_batch_by_id():
             experiment = experiments[0]
 
             ensemble = experiment.get_ensemble_by_name(case_name)
-            return ensemble.load_all_summary_data()
+            realizations = ensemble.get_realization_list_with_responses()
+            try:
+                df_pl = ensemble.load_responses("summary", tuple(realizations))
+            except (ValueError, KeyError):
+                return pd.DataFrame()
+            df_pl = df_pl.pivot(
+                on="response_key", index=["realization", "time"], sort_columns=True
+            )
+            df_pl = df_pl.rename({"time": "Date", "realization": "Realization"})
+            return (
+                df_pl.to_pandas()
+                .set_index(["Realization", "Date"])
+                .sort_values(by=["Date", "Realization"])
+            )
 
         batches = {elem[MetaDataColumnNames.BATCH] for elem in metadata}
         batch_data = []
diff --git a/tests/everest/test_api_snapshots.py b/tests/everest/test_api_snapshots.py
index 4b2d1fcfac2..441f8f8f398 100644
--- a/tests/everest/test_api_snapshots.py
+++ b/tests/everest/test_api_snapshots.py
@@ -133,7 +133,7 @@ def test_api_summary_snapshot(
                 ens.save_response("summary", smry_data.clone(), real)
 
     api = EverestDataAPI(config)
-    dicts = polars.from_pandas(api.summary_values()).to_dicts()
+    dicts = api.summary_values().to_dicts()
     snapshot.assert_match(
         orjson.dumps(dicts, option=orjson.OPT_INDENT_2).decode("utf-8").strip() + "\n",
         "snapshot.json",