From ff566dd17d7415a392bf1ae4fa70cd823db24893 Mon Sep 17 00:00:00 2001 From: Daniel <63580393+danrgll@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:32:40 +0200 Subject: [PATCH] add providing arguments for loaded class BaseOptimizer via yaml + fix pre-commit related errors --- neps/api.py | 10 ++++++---- neps/utils/run_args.py | 16 ++++++++-------- tests/test_yaml_run_args/test_yaml_run_args.py | 13 +++++++------ 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/neps/api.py b/neps/api.py index e91d2ec2..3f35bd88 100644 --- a/neps/api.py +++ b/neps/api.py @@ -224,7 +224,6 @@ def run( if run_args: optim_settings = get_run_args_from_yaml(run_args) check_double_reference(run, locals(), optim_settings) - run_pipeline = optim_settings.get("run_pipeline", run_pipeline) root_directory = optim_settings.get("root_directory", root_directory) pipeline_space = optim_settings.get("pipeline_space", pipeline_space) @@ -250,8 +249,8 @@ def run( cost_value_on_error) pre_load_hooks = optim_settings.get("pre_load_hooks", pre_load_hooks) searcher = optim_settings.get("searcher", searcher) - for key, value in optim_settings.get("searcher_kwargs", searcher_kwargs).items(): - searcher_kwargs[key] = value + # considers arguments of a provided SubClass of BaseOptimizer + searcher_class_arguments = optim_settings.get("custom_class_searcher_kwargs", {}) # check if necessary arguments are provided. check_essential_arguments( @@ -283,7 +282,10 @@ def run( if inspect.isclass(searcher): if issubclass(searcher, BaseOptimizer): search_space = SearchSpace(**pipeline_space) - searcher = searcher(search_space, **searcher_kwargs) + # aligns with the behavior of the internal neps searcher which also overwrites + # its arguments by using searcher_kwargs + merge_kwargs = {**searcher_class_arguments, **searcher_kwargs} + searcher = searcher(search_space, **merge_kwargs) else: # Raise an error if searcher is not a subclass of BaseOptimizer raise TypeError( diff --git a/neps/utils/run_args.py b/neps/utils/run_args.py index 0f5c005f..46c12414 100644 --- a/neps/utils/run_args.py +++ b/neps/utils/run_args.py @@ -35,7 +35,9 @@ IGNORE_ERROR = "ignore_errors" SEARCHER = "searcher" PRE_LOAD_HOOKS = "pre_load_hooks" -SEARCHER_KWARGS = "searcher_kwargs" +# searcher_kwargs is used differently in yaml and just play a role for considering +# arguments of a custom searcher class (BaseOptimizer) +SEARCHER_KWARGS = "custom_class_searcher_kwargs" MAX_EVALUATIONS_PER_RUN = "max_evaluations_per_run" @@ -229,6 +231,7 @@ def process_pipeline_space(key: str, special_configs: dict, settings: dict) -> N """ if special_configs.get(key) is not None: pipeline_space = special_configs[key] + # Define the type of processed_pipeline_space to accommodate both situations if isinstance(pipeline_space, dict): # determine if dict contains path_loading or the actual search space expected_keys = {"path", "name"} @@ -240,7 +243,7 @@ def process_pipeline_space(key: str, special_configs: dict, settings: dict) -> N # pipeline_space stored in a python dict, not using a yaml processed_pipeline_space = load_and_return_object( pipeline_space["path"], pipeline_space["name"], key - ) + ) # type: ignore elif isinstance(pipeline_space, str): # load yaml from path processed_pipeline_space = pipeline_space_from_yaml(pipeline_space) @@ -333,7 +336,7 @@ def load_and_return_object(module_path: str, object_name: str, key: str) -> obje the issue. """ - def import_object(path): + def import_object(path: str) -> object | None: try: # Convert file system path to module path, removing '.py' if present. module_name = ( @@ -433,7 +436,7 @@ def check_run_args(settings: dict) -> None: if not all(callable(item) for item in value): raise TypeError("All items in 'pre_load_hooks' must be callable.") elif param == SEARCHER: - if not (isinstance(param, (str, dict)) or issubclass(param, BaseOptimizer)): + if not (isinstance(value, (str, dict)) or issubclass(value, BaseOptimizer)): raise TypeError( "Parameter 'searcher' must be a string or a class that is a subclass " "of BaseOptimizer." @@ -443,7 +446,7 @@ def check_run_args(settings: dict) -> None: expected_type = expected_types[param] except KeyError as e: raise KeyError(f"{param} is not a valid argument of neps") from e - if not isinstance(value, expected_type): + if not isinstance(value, expected_type): # type: ignore raise TypeError( f"Parameter '{param}' expects a value of type {expected_type}, got " f"{type(value)} instead." @@ -517,9 +520,6 @@ def check_double_reference( if name == RUN_ARGS: # Ignoring run_args argument continue - if name == SEARCHER_KWARGS and func_arguments[name] == {}: - continue - if name in yaml_arguments: raise ValueError( f"Conflict for argument '{name}': Argument is defined both via " diff --git a/tests/test_yaml_run_args/test_yaml_run_args.py b/tests/test_yaml_run_args/test_yaml_run_args.py index 99f6d070..3ce2185f 100644 --- a/tests/test_yaml_run_args/test_yaml_run_args.py +++ b/tests/test_yaml_run_args/test_yaml_run_args.py @@ -17,14 +17,14 @@ def run_pipeline(): return -def hook1(): +def hook1(sampler): """func to test loading of pre_load_hooks""" - return + return sampler -def hook2(): +def hook2(sampler): """func to test loading of pre_load_hooks""" - return + return sampler def check_run_args(yaml_path_run_args: str, expected_output: Dict) -> None: @@ -178,8 +178,9 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]], "loss_value_on_error": 2.4, "cost_value_on_error": 2.1, "ignore_errors": False, - "searcher": {"strategy": "bayesian_optimization", "initial_design_size": 5, - "surrogate_model": "gp"}, + "searcher": BayesianOptimization, + "custom_class_searcher_kwargs": {'initial_design_size': 5, + 'surrogate_model': 'gp'}, "pre_load_hooks": [hook1] })