Commit a147425
dafeda committed Dec 5, 2024
1 parent 0ba0626 commit a147425
Showing 9 changed files with 403 additions and 480 deletions.
16 changes: 15 additions & 1 deletion src/ert/analysis/_es_update.py
@@ -4,6 +4,7 @@
import logging
import time
from fnmatch import fnmatch
from itertools import groupby
from typing import (
TYPE_CHECKING,
Callable,
@@ -256,6 +257,7 @@ def _load_observations_and_responses(
npt.NDArray[np.float64],
Tuple[
npt.NDArray[np.float64],
List[str],
npt.NDArray[np.float64],
List[ObservationAndResponseSnapshot],
],
@@ -349,7 +351,7 @@
)
)

if len(scaling_factors_dfs):
if scaling_factors_dfs:
scaling_factors_df = polars.concat(scaling_factors_dfs)
ensemble.save_observation_scaling_factors(scaling_factors_df)

@@ -405,6 +407,7 @@ def _load_observations_and_responses(

return S[obs_mask], (
observations[obs_mask],
obs_keys[obs_mask],
scaled_errors[obs_mask],
update_snapshot,
)
@@ -548,6 +551,7 @@ def adaptive_localization_progress_callback(
S,
(
observation_values,
observation_keys,
observation_errors,
update_snapshot,
),
@@ -564,6 +568,14 @@ def adaptive_localization_progress_callback(
num_obs = len(observation_values)

smoother_snapshot.update_step_snapshots = update_snapshot
# Used as labels for observations in the cross-correlation matrix.
# Say we have two observation groups, "FOPR" and "WOPR", where "FOPR" has
# 2 responses and "WOPR" has 3.
# In that case we create the list [FOPR_0, FOPR_1, WOPR_0, WOPR_1, WOPR_2]
# as labels for the observations.
unique_obs_names = [
f"{k}_{i}" for k, g in groupby(observation_keys) for i, _ in enumerate(list(g))
]

if num_obs == 0:
msg = "No active observations for update step"
@@ -667,6 +679,7 @@ def correlation_callback(
_cross_correlations,
param_group,
parameter_names[: _cross_correlations.shape[0]],
unique_obs_names,
)
logger.info(
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
@@ -729,6 +742,7 @@ def analysis_IES(
S,
(
observation_values,
_,
observation_errors,
update_snapshot,
),
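For reference, a minimal standalone sketch (not part of the commit) of the unique_obs_names labeling introduced above. Note that itertools.groupby only groups consecutive equal keys, so the scheme relies on observation_keys arriving already grouped by observation name:

from itertools import groupby

observation_keys = ["FOPR", "FOPR", "WOPR", "WOPR", "WOPR"]
# groupby yields one (key, group) pair per run of equal keys; enumerating
# each run numbers its members, giving one unique label per observation.
unique_obs_names = [
    f"{k}_{i}" for k, g in groupby(observation_keys) for i, _ in enumerate(g)
]
print(unique_obs_names)  # ['FOPR_0', 'FOPR_1', 'WOPR_0', 'WOPR_1', 'WOPR_2']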
37 changes: 23 additions & 14 deletions src/ert/storage/local_ensemble.py
@@ -11,6 +11,7 @@

import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
from pydantic import BaseModel
from typing_extensions import deprecated
@@ -559,16 +560,15 @@ def load_parameters(

return self._load_dataset(group, realizations)

def load_cross_correlations(self) -> xr.Dataset:
input_path = self.mount_point / "corr_XY.nc"

def load_cross_correlations(self) -> pl.DataFrame:
input_path = self.mount_point / "corr_XY.parquet"
if not input_path.exists():
raise FileNotFoundError(
f"No cross-correlation data available at '{input_path}'. Make sure to run the update with "
"Adaptive Localization enabled."
)
logger.info("Loading cross correlations")
return xr.open_dataset(input_path, engine="scipy")
return pl.read_parquet(input_path)

@require_write
def save_observation_scaling_factors(self, dataset: polars.DataFrame) -> None:
@@ -591,17 +591,26 @@ def save_cross_correlations(
cross_correlations: npt.NDArray[np.float64],
param_group: str,
parameter_names: List[str],
unique_obs_names: List[str],
) -> None:
data_vars = {
param_group: xr.DataArray(
data=cross_correlations,
dims=["parameter", "response"],
coords={"parameter": parameter_names},
)
}
dataset = xr.Dataset(data_vars)
file_path = os.path.join(self.mount_point, "corr_XY.nc")
self._storage._to_netcdf_transaction(file_path, dataset)
n_responses = cross_correlations.shape[1]
new_df = pl.DataFrame(
{
"param_group": [param_group]
* (len(parameter_names) * len(unique_obs_names)),
"param_name": np.repeat(parameter_names, n_responses),
"observation": unique_obs_names * len(parameter_names),
"value": cross_correlations.ravel(),
}
)

file_path = os.path.join(self.mount_point, "corr_XY.parquet")
if os.path.exists(file_path):
existing_df = pl.read_parquet(file_path)
df = pl.concat([existing_df, new_df])
else:
df = new_df
self._storage._to_parquet_transaction(file_path, df)

@lru_cache # noqa: B019
def load_responses(self, key: str, realizations: Tuple[int]) -> polars.DataFrame:
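To make the stored layout concrete, here is a hedged sketch (toy group, names, and values; not part of the commit) of the long-format frame that save_cross_correlations builds: np.repeat repeats each parameter name once per response, the observation labels are tiled once per parameter, and ravel flattens the correlation matrix row by row, so each row of the frame is one (parameter, observation) pair:

import numpy as np
import polars as pl

param_group = "COND"  # illustrative group name
parameter_names = ["p0", "p1"]
unique_obs_names = ["FOPR_0", "FOPR_1", "WOPR_0"]
cross_correlations = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

n_responses = cross_correlations.shape[1]
df = pl.DataFrame(
    {
        "param_group": [param_group] * (len(parameter_names) * len(unique_obs_names)),
        "param_name": np.repeat(parameter_names, n_responses),
        "observation": unique_obs_names * len(parameter_names),
        "value": cross_correlations.ravel(),
    }
)
# Six rows: (p0, FOPR_0, 0.1), (p0, FOPR_1, 0.2), (p0, WOPR_0, 0.3),
#           (p1, FOPR_0, 0.4), (p1, FOPR_1, 0.5), (p1, WOPR_0, 0.6)

The read-then-concat in the commit means successive parameter groups within one update append to the same corr_XY.parquet file for the ensemble.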
192 changes: 192 additions & 0 deletions test-data/ert/heat_equation/Plot_correlations.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions test-data/ert/heat_equation/config.ert
@@ -13,13 +13,19 @@ QUEUE_OPTION LOCAL MAX_RUNNING 100

RANDOM_SEED 11223344

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1

NUM_REALIZATIONS 100
GRID CASE.EGRID

OBS_CONFIG observations

FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True

GEN_KW INIT_TEMP_SCALE init_temp_prior.txt
GEN_KW CORR_LENGTH corr_length_prior.txt

GEN_DATA MY_RESPONSE RESULT_FILE:gen_data_%d.out REPORT_STEPS:10,71,132,193,255,316,377,438 INPUT_FORMAT:ASCII

INSTALL_JOB heat_equation HEAT_EQUATION
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/corr_length_prior.txt
@@ -0,0 +1 @@
x NORMAL 0.8 0.1
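For context: each GEN_KW prior file declares one parameter per line as <name> <distribution> <distribution arguments>, so x NORMAL 0.8 0.1 gives the single parameter x a normal prior with mean 0.8 and standard deviation 0.1, while x UNIFORM 0 1 (further below) gives it a uniform prior on [0, 1]. These are the values the forward model reads back from parameters.json.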
23 changes: 19 additions & 4 deletions test-data/ert/heat_equation/heat_equation.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Partial Differential Equations to use as forward models."""

import json
import sys
from typing import Optional

@@ -52,16 +53,28 @@ def heat_equation(
return _u


def sample_prior_conductivity(ensemble_size, nx, rng):
def sample_prior_conductivity(ensemble_size, nx, rng, corr_length):
mesh = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 1, nx))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=0.8))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=corr_length))


def load_parameters(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)


if __name__ == "__main__":
iens = int(sys.argv[1])
iteration = int(sys.argv[2])
rng = np.random.default_rng(iens)
cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx)

parameters = load_parameters("parameters.json")
init_temp_scale = parameters["INIT_TEMP_SCALE"]
corr_length = parameters["CORR_LENGTH"]

cond = sample_prior_conductivity(
ensemble_size=1, nx=nx, rng=rng, corr_length=float(corr_length["x"])
).reshape(nx, nx)

if iteration == 0:
resfo.write(
@@ -79,7 +92,9 @@ def sample_prior_conductivity(ensemble_size, nx, rng):
# Note that this could be avoided if we used an implicit solver.
dt = dx**2 / (4 * max(np.max(cond), np.max(cond)))

response = heat_equation(u_init, cond, dx, dt, k_start, k_end, rng)
scaled_u_init = u_init * float(init_temp_scale["x"])

response = heat_equation(scaled_u_init, cond, dx, dt, k_start, k_end, rng)

index = sorted((obs.x, obs.y) for obs in obs_coordinates)
for time_step in obs_times:
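Two details worth noting in the forward model above. First, the time step follows the usual stability bound for an explicit scheme on the 2D heat equation, dt <= dx**2 / (4 * k_max); as written, max(np.max(cond), np.max(cond)) compares the same array with itself, so it reduces to np.max(cond). Second, the GEN_KW values arrive through the parameters.json file that ert writes into each runpath. A hedged, self-contained sketch of that handoff (key names taken from the config above; the sampled values are illustrative):

import json

# Write an illustrative parameters.json like the one ert places in each
# runpath: one entry per GEN_KW group, mapping parameter name to value.
example = {"INIT_TEMP_SCALE": {"x": 0.42}, "CORR_LENGTH": {"x": 0.81}}
with open("parameters.json", "w", encoding="utf-8") as f:
    json.dump(example, f)

with open("parameters.json", encoding="utf-8") as f:
    parameters = json.load(f)

init_temp_scale = float(parameters["INIT_TEMP_SCALE"]["x"])  # scales u_init
corr_length = float(parameters["CORR_LENGTH"]["x"])  # field correlation length
assert 0.0 <= init_temp_scale <= 1.0  # consistent with the UNIFORM 0 1 prior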
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/init_temp_prior.txt
@@ -0,0 +1 @@
x UNIFORM 0 1
540 changes: 105 additions & 435 deletions test-data/ert/poly_example/Plot_correlations.ipynb

Large diffs are not rendered by default.

67 changes: 41 additions & 26 deletions tests/ert/ui_tests/cli/test_field_parameter.py
@@ -23,38 +23,53 @@
from .run_cli import run_cli


def test_field_param_update_using_heat_equation(heat_equation_storage):
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
def test_shared_heat_equation_storage(heat_equation_storage):
"""The fixture heat_equation_storage runs the heat equation test case.
This test verifies that results are as expected.
"""
config = heat_equation_storage
with open_storage(config.ens_path) as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
posterior = experiment.get_ensemble_by_name("default_1")

prior_result = prior.load_parameters("COND")["values"]
ensembles = [experiment.get_ensemble_by_name(f"default_{i}") for i in range(4)]

param_config = config.ensemble_config.parameter_configs["COND"]
assert len(prior_result.x) == param_config.nx
assert len(prior_result.y) == param_config.ny
assert len(prior_result.z) == param_config.nz

posterior_result = posterior.load_parameters("COND")["values"]
prior_covariance = np.cov(
prior_result.values.reshape(
prior.ensemble_size, param_config.nx * param_config.ny * param_config.nz
),
rowvar=False,
)
posterior_covariance = np.cov(
posterior_result.values.reshape(
posterior.ensemble_size,
# Check that generalized variance decreases across consecutive ensembles
covariances = []
for ensemble in ensembles:
results = ensemble.load_parameters("COND")["values"]
reshaped_values = results.values.reshape(
ensemble.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
),
rowvar=False,
)
# Check that generalized variance is reduced by update step.
assert np.trace(prior_covariance) > np.trace(posterior_covariance)
)
covariances.append(np.cov(reshaped_values, rowvar=False))
for i in range(len(covariances) - 1):
assert np.trace(covariances[i]) > np.trace(
covariances[i + 1]
), f"Generalized variance did not decrease from iteration {i} to {i + 1}"

# Check that the saved cross-correlations are as expected.
for i in range(3):
ensemble = ensembles[i]
corr_XY = ensemble.load_cross_correlations()

assert corr_XY["param_group"].unique().to_list() == [
"INIT_TEMP_SCALE",
"CORR_LENGTH",
]
assert corr_XY["param_name"].unique().to_list() == ["x"]

# Make sure correlations are between -1 and 1.
is_valid = (corr_XY["value"] >= -1) & (corr_XY["value"] <= 1)
assert is_valid.all()

# Check observation keys
expected_observation_keys = [obs[0] for obs in config.observation_config]
observation_keys = corr_XY["observation"].unique().to_list()
observation_keys = {key[:-2] for key in observation_keys}
assert sorted(observation_keys) == sorted(expected_observation_keys)

# Check that fields in the runpath are different between iterations
# Check that fields in the runpath are different between ensembles
cond_iter0 = resfo.read("simulations/realization-0/iter-0/cond.bgrdecl")[0][1]
cond_iter1 = resfo.read("simulations/realization-0/iter-1/cond.bgrdecl")[0][1]
assert (cond_iter0 != cond_iter1).all()
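As a usage note, the stored cross-correlations can also be inspected directly, outside the test. A hedged sketch (the storage path and experiment name are assumptions based on the test case above):

import polars as pl
from ert.storage import open_storage

with open_storage("storage") as storage:  # path depends on ENSPATH in config.ert
    experiment = storage.get_experiment_by_name("es-mda")
    corr_XY = experiment.get_ensemble_by_name("default_0").load_cross_correlations()
    # Strongest parameter-observation correlations first:
    print(corr_XY.sort(pl.col("value").abs(), descending=True).head())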
