Skip to content

Commit

Permalink
Replace usage of load_all_summary_data() in everest data api
Browse files Browse the repository at this point in the history
  • Loading branch information
DanSava committed Dec 3, 2024
1 parent 22687ea commit 3027259
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 27 deletions.
51 changes: 26 additions & 25 deletions src/everest/api/everest_data_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import OrderedDict

import pandas as pd
import polars as pl
from seba_sqlite.snapshot import SebaSnapshot

from ert.storage import open_storage
Expand Down Expand Up @@ -154,48 +154,49 @@ def gradient_values(self):
def summary_values(self, batches=None, keys=None):
if batches is None:
batches = self.batches
simulations = self.simulations
data_frames = []
storage = open_storage(self._config.storage_dir, "r")
experiment = next(storage.experiments)
for batch_id in batches:
ensemble = experiment.get_ensemble_by_name(f"batch_{batch_id}")
summary = ensemble.load_all_summary_data()
if not summary.empty:
columns = set(summary.columns)
if keys is not None:
columns = columns.intersection(set(keys))
summary = summary[list(columns)]
summary = summary.dropna(axis=0, how="all", subset=columns)
summary = summary.dropna(axis=1, how="all")
summary = summary[
summary.index.get_level_values("Realization").isin(simulations)
]
summary.reset_index(inplace=True)
summary["batch"] = batch_id
try:
summary = ensemble.load_responses(
key="summary",
realizations=tuple(self.simulations),
)
except (ValueError, KeyError):
summary = pl.DataFrame()

if not summary.is_empty():
summary = summary.pivot(
on="response_key", index=["realization", "time"], sort_columns=True
)
# The 'Realization' column exported by ert are
# the 'simulations' of everest.
summary.rename(
columns={"Realization": "simulation", "Date": "date"}, inplace=True
summary = summary.rename({"time": "date", "realization": "simulation"})
if keys is not None:
columns = set(summary.columns).intersection(set(keys))
summary = summary[["date", "simulation", *list(columns)]]
summary = summary.with_columns(
pl.Series("batch", [batch_id] * summary.shape[0])
)
# The realization ID as defined by Everest must be
# retrieved via the seba snapshot.
realization_map = {
str(sim.simulation): sim.realization
sim.simulation: sim.realization
for sim in self._snapshot.simulation_data
if sim.batch == batch_id
}
summary["realization"] = (
summary["simulation"].astype(str).map(realization_map)
)
# If possible, convert the realization id to integer.
summary["realization"] = pd.to_numeric(
summary["realization"], errors="ignore", downcast="integer"
realizations = pl.Series(
"realization",
[realization_map.get(str(sim)) for sim in summary["simulation"]],
)
realizations = realizations.cast(pl.Int64, strict=False)
summary = summary.with_columns(realizations)

data_frames.append(summary)
storage.close()
return pd.concat(data_frames)
return pl.concat(data_frames)

@property
def output_folder(self):
Expand Down
15 changes: 14 additions & 1 deletion src/everest/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,20 @@ def load_batch_by_id():
experiment = experiments[0]

ensemble = experiment.get_ensemble_by_name(case_name)
return ensemble.load_all_summary_data()
realizations = ensemble.get_realization_list_with_responses()
try:
df_pl = ensemble.load_responses("summary", tuple(realizations))
except (ValueError, KeyError):
return pd.DataFrame()
df_pl = df_pl.pivot(
on="response_key", index=["realization", "time"], sort_columns=True
)
df_pl = df_pl.rename({"time": "Date", "realization": "Realization"})
return (
df_pl.to_pandas()
.set_index(["Realization", "Date"])
.sort_values(by=["Date", "Realization"])
)

batches = {elem[MetaDataColumnNames.BATCH] for elem in metadata}
batch_data = []
Expand Down
2 changes: 1 addition & 1 deletion tests/everest/test_api_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_api_summary_snapshot(
ens.save_response("summary", smry_data.clone(), real)

api = EverestDataAPI(config)
dicts = polars.from_pandas(api.summary_values()).to_dicts()
dicts = api.summary_values().to_dicts()
snapshot.assert_match(
orjson.dumps(dicts, option=orjson.OPT_INDENT_2).decode("utf-8").strip() + "\n",
"snapshot.json",
Expand Down

0 comments on commit 3027259

Please sign in to comment.