Skip to content

Commit

Permalink
add providing arguments for loaded class BaseOptimizer via yaml + fix…
Browse files Browse the repository at this point in the history
… pre-commit related errors
  • Loading branch information
danrgll committed Jun 17, 2024
1 parent c06dde7 commit ff566dd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
10 changes: 6 additions & 4 deletions neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def run(
if run_args:
optim_settings = get_run_args_from_yaml(run_args)
check_double_reference(run, locals(), optim_settings)

run_pipeline = optim_settings.get("run_pipeline", run_pipeline)
root_directory = optim_settings.get("root_directory", root_directory)
pipeline_space = optim_settings.get("pipeline_space", pipeline_space)
Expand All @@ -250,8 +249,8 @@ def run(
cost_value_on_error)
pre_load_hooks = optim_settings.get("pre_load_hooks", pre_load_hooks)
searcher = optim_settings.get("searcher", searcher)
for key, value in optim_settings.get("searcher_kwargs", searcher_kwargs).items():
searcher_kwargs[key] = value
# considers arguments of a provided SubClass of BaseOptimizer
searcher_class_arguments = optim_settings.get("custom_class_searcher_kwargs", {})

# check if necessary arguments are provided.
check_essential_arguments(
Expand Down Expand Up @@ -283,7 +282,10 @@ def run(
if inspect.isclass(searcher):
if issubclass(searcher, BaseOptimizer):
search_space = SearchSpace(**pipeline_space)
searcher = searcher(search_space, **searcher_kwargs)
# aligns with the behavior of the internal neps searcher which also overwrites
# its arguments by using searcher_kwargs
merge_kwargs = {**searcher_class_arguments, **searcher_kwargs}
searcher = searcher(search_space, **merge_kwargs)
else:
# Raise an error if searcher is not a subclass of BaseOptimizer
raise TypeError(
Expand Down
16 changes: 8 additions & 8 deletions neps/utils/run_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
IGNORE_ERROR = "ignore_errors"
SEARCHER = "searcher"
PRE_LOAD_HOOKS = "pre_load_hooks"
SEARCHER_KWARGS = "searcher_kwargs"
# searcher_kwargs is used differently in yaml and just play a role for considering
# arguments of a custom searcher class (BaseOptimizer)
SEARCHER_KWARGS = "custom_class_searcher_kwargs"
MAX_EVALUATIONS_PER_RUN = "max_evaluations_per_run"


Expand Down Expand Up @@ -229,6 +231,7 @@ def process_pipeline_space(key: str, special_configs: dict, settings: dict) -> N
"""
if special_configs.get(key) is not None:
pipeline_space = special_configs[key]
# Define the type of processed_pipeline_space to accommodate both situations
if isinstance(pipeline_space, dict):
# determine if dict contains path_loading or the actual search space
expected_keys = {"path", "name"}
Expand All @@ -240,7 +243,7 @@ def process_pipeline_space(key: str, special_configs: dict, settings: dict) -> N
# pipeline_space stored in a python dict, not using a yaml
processed_pipeline_space = load_and_return_object(
pipeline_space["path"], pipeline_space["name"], key
)
) # type: ignore
elif isinstance(pipeline_space, str):
# load yaml from path
processed_pipeline_space = pipeline_space_from_yaml(pipeline_space)
Expand Down Expand Up @@ -333,7 +336,7 @@ def load_and_return_object(module_path: str, object_name: str, key: str) -> obje
the issue.
"""

def import_object(path):
def import_object(path: str) -> object | None:
try:
# Convert file system path to module path, removing '.py' if present.
module_name = (
Expand Down Expand Up @@ -433,7 +436,7 @@ def check_run_args(settings: dict) -> None:
if not all(callable(item) for item in value):
raise TypeError("All items in 'pre_load_hooks' must be callable.")
elif param == SEARCHER:
if not (isinstance(param, (str, dict)) or issubclass(param, BaseOptimizer)):
if not (isinstance(value, (str, dict)) or issubclass(value, BaseOptimizer)):
raise TypeError(
"Parameter 'searcher' must be a string or a class that is a subclass "
"of BaseOptimizer."
Expand All @@ -443,7 +446,7 @@ def check_run_args(settings: dict) -> None:
expected_type = expected_types[param]
except KeyError as e:
raise KeyError(f"{param} is not a valid argument of neps") from e
if not isinstance(value, expected_type):
if not isinstance(value, expected_type): # type: ignore
raise TypeError(
f"Parameter '{param}' expects a value of type {expected_type}, got "
f"{type(value)} instead."
Expand Down Expand Up @@ -517,9 +520,6 @@ def check_double_reference(
if name == RUN_ARGS:
# Ignoring run_args argument
continue
if name == SEARCHER_KWARGS and func_arguments[name] == {}:
continue

if name in yaml_arguments:
raise ValueError(
f"Conflict for argument '{name}': Argument is defined both via "
Expand Down
13 changes: 7 additions & 6 deletions tests/test_yaml_run_args/test_yaml_run_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def run_pipeline():
return


def hook1():
def hook1(sampler):
"""func to test loading of pre_load_hooks"""
return
return sampler


def hook2():
def hook2(sampler):
"""func to test loading of pre_load_hooks"""
return
return sampler


def check_run_args(yaml_path_run_args: str, expected_output: Dict) -> None:
Expand Down Expand Up @@ -178,8 +178,9 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]],
"loss_value_on_error": 2.4,
"cost_value_on_error": 2.1,
"ignore_errors": False,
"searcher": {"strategy": "bayesian_optimization", "initial_design_size": 5,
"surrogate_model": "gp"},
"searcher": BayesianOptimization,
"custom_class_searcher_kwargs": {'initial_design_size': 5,
'surrogate_model': 'gp'},
"pre_load_hooks": [hook1]
})
Expand Down

0 comments on commit ff566dd

Please sign in to comment.