diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 667213e2b42..963c0ebbe9d 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -21,7 +21,7 @@ from ert.config import GenKwConfig from ..config.analysis_config import ObservationGroups, UpdateSettings -from ..config.analysis_module import ESSettings, IESSettings +from ..config.analysis_module import BaseSettings, ESSettings, IESSettings from . import misfit_preprocessor from .event import ( AnalysisCompleteEvent, @@ -747,8 +747,8 @@ def smoother_update( posterior_storage: Ensemble, observations: Iterable[str], parameters: Iterable[str], - analysis_config: UpdateSettings | None = None, - es_settings: ESSettings | None = None, + analysis_config: UpdateSettings, + es_settings: BaseSettings, rng: np.random.Generator | None = None, progress_callback: Callable[[AnalysisEvent], None] | None = None, global_scaling: float = 1.0, @@ -757,8 +757,12 @@ def smoother_update( progress_callback = noop_progress_callback if rng is None: rng = np.random.default_rng() - analysis_config = UpdateSettings() if analysis_config is None else analysis_config - es_settings = ESSettings() if es_settings is None else es_settings + + assert isinstance(es_settings, ESSettings) + + # analysis_config = UpdateSettings() if analysis_config is None else analysis_config + # es_settings = ESSettings() if es_settings is None else es_settings + ens_mask = prior_storage.get_realization_mask_with_responses() smoother_snapshot = _create_smoother_snapshot( diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index bd6392647de..ad8ec4cd44a 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -31,7 +31,8 @@ AnalysisDataEvent, AnalysisErrorEvent, ) -from ert.config import ErtConfig, ESSettings, HookRuntime, QueueSystem +from ert.config import ErtConfig, HookRuntime, QueueSystem +from ert.config.analysis_module import BaseSettings from ert.enkf_main import _seed_sequence, create_run_path from ert.ensemble_evaluator import Ensemble as EEEnsemble from ert.ensemble_evaluator import ( @@ -732,7 +733,7 @@ def _evaluate_and_postprocess( class UpdateRunModel(BaseRunModel): def __init__( self, - es_settings: ESSettings, + analysis_settings: BaseSettings, update_settings: UpdateSettings, config: ErtConfig, storage: Storage, @@ -744,8 +745,9 @@ def __init__( random_seed: int | None, minimum_required_realizations: int, ): - self.es_settings = es_settings - self.update_settings = update_settings + self._analysis_settings: BaseSettings = analysis_settings + self._update_settings: UpdateSettings = update_settings + super().__init__( config, storage, @@ -786,8 +788,8 @@ def update( smoother_update( prior, posterior, - analysis_config=self.update_settings, - es_settings=self.es_settings, + analysis_config=self._update_settings, + es_settings=self._analysis_settings, parameters=prior.experiment.update_parameters, observations=prior.experiment.observation_keys, global_scaling=weight, diff --git a/tests/ert/performance_tests/test_memory_usage.py b/tests/ert/performance_tests/test_memory_usage.py index d1e0bb24ad0..abbe8d98ccf 100644 --- a/tests/ert/performance_tests/test_memory_usage.py +++ b/tests/ert/performance_tests/test_memory_usage.py @@ -13,7 +13,8 @@ import xtgeo from ert.analysis import smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.enkf_main import sample_prior from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE from ert.storage import open_storage @@ -65,6 +66,8 @@ def test_memory_smoothing(poly_template): posterior_ens, list(experiment.observation_keys), list(ert_config.ensemble_config.parameters), + UpdateSettings(), + ESSettings(), ) stats = memray._memray.compute_statistics(str(poly_template / "memray.bin")) diff --git a/tests/ert/performance_tests/test_obs_and_responses_performance.py b/tests/ert/performance_tests/test_obs_and_responses_performance.py index 8c70728b415..0e28f4ba42e 100644 --- a/tests/ert/performance_tests/test_obs_and_responses_performance.py +++ b/tests/ert/performance_tests/test_obs_and_responses_performance.py @@ -8,7 +8,8 @@ import pytest from ert.analysis import smoother_update -from ert.config import GenDataConfig, GenKwConfig, SummaryConfig +from ert.config import ESSettings, GenDataConfig, GenKwConfig, SummaryConfig +from ert.config.analysis_config import UpdateSettings from ert.config.gen_kw_config import TransformFunctionDefinition from ert.enkf_main import sample_prior from ert.storage import open_storage @@ -506,6 +507,8 @@ def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path): posterior, prior.experiment.observation_keys, [gen_kw_name], + UpdateSettings(), + ESSettings(), ) stats = memray._memray.compute_statistics(str(tmp_path / "memray.bin")) @@ -525,6 +528,8 @@ def run(): posterior, prior.experiment.observation_keys, [gen_kw_name], + UpdateSettings(), + ESSettings(), ) benchmark(run) diff --git a/tests/ert/unit_tests/scenarios/test_summary_response.py b/tests/ert/unit_tests/scenarios/test_summary_response.py index 7e4102beff9..0cf556272bb 100644 --- a/tests/ert/unit_tests/scenarios/test_summary_response.py +++ b/tests/ert/unit_tests/scenarios/test_summary_response.py @@ -11,7 +11,8 @@ from ert import LibresFacade from ert.analysis import ErtAnalysisError, smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.data import MeasuredData from ert.enkf_main import sample_prior @@ -102,6 +103,8 @@ def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble): target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -129,6 +132,8 @@ def test_that_mismatched_responses_give_error(ert_config, storage, prior_ensembl target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -160,6 +165,8 @@ def test_that_different_length_is_ok_as_long_as_observation_time_exists( target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -206,6 +213,8 @@ def test_that_duplicate_summary_time_steps_does_not_fail( target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) diff --git a/tests/ert/unit_tests/storage/test_storage_migration.py b/tests/ert/unit_tests/storage/test_storage_migration.py index 952e6b16055..1721b954a55 100644 --- a/tests/ert/unit_tests/storage/test_storage_migration.py +++ b/tests/ert/unit_tests/storage/test_storage_migration.py @@ -10,7 +10,8 @@ from packaging import version from ert.analysis import ErtAnalysisError, smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.storage import open_storage from ert.storage.local_storage import ( _LOCAL_STORAGE_VERSION, @@ -467,6 +468,8 @@ def test_that_manual_update_from_migrated_storage_works( posterior_ens, list(experiment.observation_keys), list(ert_config.ensemble_config.parameters), + UpdateSettings(), + ESSettings(), )