Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor update workflow signature #9688

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ert.config import GenKwConfig

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
from ..config.analysis_module import BaseSettings, ESSettings, IESSettings
from . import misfit_preprocessor
from .event import (
AnalysisCompleteEvent,
Expand Down Expand Up @@ -747,8 +747,8 @@ def smoother_update(
posterior_storage: Ensemble,
observations: Iterable[str],
parameters: Iterable[str],
analysis_config: UpdateSettings | None = None,
es_settings: ESSettings | None = None,
analysis_config: UpdateSettings,
es_settings: BaseSettings,
rng: np.random.Generator | None = None,
progress_callback: Callable[[AnalysisEvent], None] | None = None,
global_scaling: float = 1.0,
Expand All @@ -757,8 +757,9 @@ def smoother_update(
progress_callback = noop_progress_callback
if rng is None:
rng = np.random.default_rng()
analysis_config = UpdateSettings() if analysis_config is None else analysis_config
es_settings = ESSettings() if es_settings is None else es_settings

assert isinstance(es_settings, ESSettings)

ens_mask = prior_storage.get_realization_mask_with_responses()

smoother_snapshot = _create_smoother_snapshot(
Expand Down
14 changes: 8 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
AnalysisDataEvent,
AnalysisErrorEvent,
)
from ert.config import ErtConfig, ESSettings, HookRuntime, QueueSystem
from ert.config import ErtConfig, HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
from ert.enkf_main import _seed_sequence, create_run_path
from ert.ensemble_evaluator import Ensemble as EEEnsemble
from ert.ensemble_evaluator import (
Expand Down Expand Up @@ -732,7 +733,7 @@ def _evaluate_and_postprocess(
class UpdateRunModel(BaseRunModel):
def __init__(
self,
es_settings: ESSettings,
analysis_settings: BaseSettings,
update_settings: UpdateSettings,
config: ErtConfig,
storage: Storage,
Expand All @@ -744,8 +745,9 @@ def __init__(
random_seed: int | None,
minimum_required_realizations: int,
):
self.es_settings = es_settings
self.update_settings = update_settings
self._analysis_settings: BaseSettings = analysis_settings
self._update_settings: UpdateSettings = update_settings

super().__init__(
config,
storage,
Expand Down Expand Up @@ -786,8 +788,8 @@ def update(
smoother_update(
prior,
posterior,
analysis_config=self.update_settings,
es_settings=self.es_settings,
analysis_config=self._update_settings,
es_settings=self._analysis_settings,
parameters=prior.experiment.update_parameters,
observations=prior.experiment.observation_keys,
global_scaling=weight,
Expand Down
5 changes: 4 additions & 1 deletion tests/ert/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import xtgeo

from ert.analysis import smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.enkf_main import sample_prior
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.storage import open_storage
Expand Down Expand Up @@ -65,6 +66,8 @@ def test_memory_smoothing(poly_template):
posterior_ens,
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
UpdateSettings(),
ESSettings(),
)

stats = memray._memray.compute_statistics(str(poly_template / "memray.bin"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

from ert.analysis import smoother_update
from ert.config import GenDataConfig, GenKwConfig, SummaryConfig
from ert.config import ESSettings, GenDataConfig, GenKwConfig, SummaryConfig
from ert.config.analysis_config import UpdateSettings
from ert.config.gen_kw_config import TransformFunctionDefinition
from ert.enkf_main import sample_prior
from ert.storage import open_storage
Expand Down Expand Up @@ -506,6 +507,8 @@ def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path):
posterior,
prior.experiment.observation_keys,
[gen_kw_name],
UpdateSettings(),
ESSettings(),
)

stats = memray._memray.compute_statistics(str(tmp_path / "memray.bin"))
Expand All @@ -525,6 +528,8 @@ def run():
posterior,
prior.experiment.observation_keys,
[gen_kw_name],
UpdateSettings(),
ESSettings(),
)

benchmark(run)
11 changes: 10 additions & 1 deletion tests/ert/unit_tests/scenarios/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from ert import LibresFacade
from ert.analysis import ErtAnalysisError, smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.data import MeasuredData
from ert.enkf_main import sample_prior

Expand Down Expand Up @@ -102,6 +103,8 @@ def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble):
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -129,6 +132,8 @@ def test_that_mismatched_responses_give_error(ert_config, storage, prior_ensembl
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -160,6 +165,8 @@ def test_that_different_length_is_ok_as_long_as_observation_time_exists(
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -206,6 +213,8 @@ def test_that_duplicate_summary_time_steps_does_not_fail(
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down
5 changes: 4 additions & 1 deletion tests/ert/unit_tests/storage/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from packaging import version

from ert.analysis import ErtAnalysisError, smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.storage import open_storage
from ert.storage.local_storage import (
_LOCAL_STORAGE_VERSION,
Expand Down Expand Up @@ -467,6 +468,8 @@ def test_that_manual_update_from_migrated_storage_works(
posterior_ens,
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
UpdateSettings(),
ESSettings(),
)


Expand Down
Loading