Skip to content

Commit 9b4ae81

Browse files
committed
Dedupe obs&responses join logic in MeasuredData
The goal of this change is to eventually remove MeasuredData. With the join logic deduplicated, external dependents no longer need to use MeasuredData itself; they can instead copy-paste the thin wrapper logic around LocalEnsemble.get_observations_and_responses, which essentially just renames some columns on a dataframe and converts it from polars to pandas.
1 parent a8d875e commit 9b4ae81

File tree

3 files changed

+44
-68
lines changed

3 files changed

+44
-68
lines changed

src/ert/data/_measured_data.py

Lines changed: 31 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414
import pandas as pd
15-
import polars as pl
1615

1716
if TYPE_CHECKING:
1817
from ert.storage import Ensemble
@@ -83,83 +82,51 @@ def is_empty(self) -> bool:
8382
@staticmethod
8483
def _get_data(
8584
ensemble: Ensemble,
86-
observation_keys: list[str],
85+
observed_response_keys: list[str],
8786
) -> pd.DataFrame:
8887
"""
8988
Adds simulated and observed data and returns a dataframe where ensemble
9089
members will have a data key, observed data will be named OBS and
9190
observed standard deviation will be named STD.
9291
"""
9392

94-
observations_by_type = ensemble.experiment.observations
93+
resp_key_to_resp_type = ensemble.experiment.response_key_to_response_type
94+
selected_response_types = {
95+
response_type
96+
for response_key, response_type in resp_key_to_resp_type.items()
97+
if response_key in observed_response_keys
98+
}
9599

96-
dfs = []
100+
active_realizations = ensemble.get_realization_list_with_responses()
97101

98-
for key in observation_keys:
99-
if key not in ensemble.experiment.observation_keys:
100-
raise ObservationError(
101-
f"No observation: {key} in ensemble: {ensemble.name}"
102-
)
103-
104-
for (
105-
response_type,
106-
response_cls,
107-
) in ensemble.experiment.response_configuration.items():
108-
observations_for_type = observations_by_type[response_type].filter(
109-
pl.col("observation_key").is_in(observation_keys)
110-
)
111-
responses_for_type = ensemble.load_responses(
112-
response_type,
113-
realizations=tuple(ensemble.get_realization_list_with_responses()),
114-
)
115-
116-
if responses_for_type.is_empty():
102+
# Check if responses exist for all selected response types
103+
for response_type in selected_response_types:
104+
df = ensemble.load_responses(response_type, tuple(active_realizations))
105+
if df.is_empty():
117106
raise ResponseError(
118107
f"No response loaded for observation type: {response_type}"
119108
)
120109

121-
# Note that if there are duplicate entries for one
122-
# response at one index, they are aggregated together
123-
# with "mean" by default
124-
pivoted = responses_for_type.pivot(
125-
on="realization",
126-
index=["response_key", *response_cls.primary_key],
127-
aggregate_function="mean",
110+
df = (
111+
ensemble.get_observations_and_responses(
112+
observed_response_keys, np.array(active_realizations)
128113
)
129-
130-
if "time" in pivoted:
131-
joined = observations_for_type.join_asof(
132-
pivoted,
133-
by=["response_key", *response_cls.primary_key],
134-
on="time",
135-
tolerance="1s",
136-
)
137-
else:
138-
joined = observations_for_type.join(
139-
pivoted,
140-
how="left",
141-
on=["response_key", *response_cls.primary_key],
142-
)
143-
144-
joined = joined.sort(by="observation_key").with_columns(
145-
pl.concat_str(response_cls.primary_key, separator=", ").alias(
146-
"key_index"
147-
)
114+
.rename(
115+
{
116+
"index": "key_index",
117+
"observations": "OBS",
118+
"std": "STD",
119+
}
148120
)
149-
150-
# Put key_index column 1st
151-
joined = joined[["key_index", *joined.columns[:-1]]]
152-
joined = joined.drop(*response_cls.primary_key)
153-
154-
if not joined.is_empty():
155-
dfs.append(joined)
156-
157-
df = pl.concat(dfs)
158-
df = df.rename(
159-
{
160-
"observations": "OBS",
161-
"std": "STD",
162-
}
121+
.select(
122+
"key_index",
123+
"response_key",
124+
"observation_key",
125+
"OBS",
126+
"STD",
127+
*map(str, active_realizations),
128+
)
129+
.sort(by="observation_key")
163130
)
164131

165132
pddf = df.to_pandas()[
@@ -175,7 +142,7 @@ def _get_data(
175142
# Pandas differentiates vs int and str keys.
176143
# Legacy-wise we use int keys for realizations
177144
pddf.rename(
178-
columns={str(k): int(k) for k in range(ensemble.ensemble_size)},
145+
columns={str(k): int(k) for k in active_realizations},
179146
inplace=True,
180147
)
181148

src/ert/storage/local_ensemble.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,15 @@ def get_observations_and_responses(
974974
) -> pl.DataFrame:
975975
"""Fetches and aligns selected observations with their
976976
corresponding simulated responses from an ensemble."""
977+
known_observations = self.experiment.observation_keys
978+
unknown_observations = [
979+
obs for obs in selected_observations if obs not in known_observations
980+
]
981+
982+
if unknown_observations:
983+
msg = f"Observations: {', '.join(unknown_observations)} not in experiment"
984+
raise KeyError(msg)
985+
977986
observations_by_type = self.experiment.observations
978987

979988
with pl.StringCache():

tests/ert/unit_tests/data/test_integration_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ert.config import ErtConfig
1010
from ert.data import MeasuredData
11-
from ert.data._measured_data import ObservationError, ResponseError
11+
from ert.data._measured_data import ResponseError
1212
from ert.libres_facade import LibresFacade
1313
from ert.storage import open_storage
1414

@@ -95,8 +95,8 @@ def test_gen_obs_and_summary(create_measured_data):
9595
@pytest.mark.parametrize(
9696
"obs_key, expected_msg",
9797
[
98-
("FOPR", r"No observation: FOPR in ensemble"),
99-
("WPR_DIFF_1", "No observation: WPR_DIFF_1 in ensemble"),
98+
("FOPR", r"Observations: FOPR not in experiment"),
99+
("WPR_DIFF_1", "Observations: WPR_DIFF_1 not in experiment"),
100100
],
101101
)
102102
def test_no_storage(obs_key, expected_msg, storage):
@@ -105,7 +105,7 @@ def test_no_storage(obs_key, expected_msg, storage):
105105
)
106106

107107
with pytest.raises(
108-
ObservationError,
108+
KeyError,
109109
match=expected_msg,
110110
):
111111
MeasuredData(ensemble, [obs_key])

0 commit comments

Comments
 (0)