update rm searcher_kwargs key from yaml for user
danrgll committed Jun 14, 2024
1 parent 12bc542 commit 5f3e541
Showing 3 changed files with 24 additions and 30 deletions.
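In practice, this commit changes how users pass searcher options in a run_args YAML: the separate searcher_kwargs key is gone, and the options nest under searcher next to a strategy field. A minimal before/after sketch in Python (values taken from the updated tests below; the YAML layout itself is illustrative):

    import yaml  # pip install pyyaml

    # Old style (pre-commit): kwargs lived in a separate top-level key.
    old_cfg = yaml.safe_load("""
    searcher: bayesian_optimization
    searcher_kwargs:
      initial_design_size: 5
      surrogate_model: gp
    """)

    # New style (post-commit): kwargs nest under 'searcher', next to 'strategy'.
    new_cfg = yaml.safe_load("""
    searcher:
      strategy: bayesian_optimization
      initial_design_size: 5
      surrogate_model: gp
    """)

    assert new_cfg["searcher"]["strategy"] == old_cfg["searcher"]
    assert {k: v for k, v in new_cfg["searcher"].items() if k != "strategy"} \
           == old_cfg["searcher_kwargs"]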
13 changes: 2 additions & 11 deletions neps/api.py
@@ -283,7 +283,7 @@ def run(
     if inspect.isclass(searcher):
         if issubclass(searcher, BaseOptimizer):
             search_space = SearchSpace(**pipeline_space)
-            searcher = searcher(search_space)
+            searcher = searcher(search_space, **searcher_kwargs)
         else:
             # Raise an error if searcher is not a subclass of BaseOptimizer
             raise TypeError(
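What this hunk buys: a custom BaseOptimizer subclass handed to neps.run is now constructed with the collected searcher kwargs rather than the search space alone. A self-contained sketch of the pattern, with BaseOptimizer stubbed out and all names illustrative:

    class BaseOptimizer:  # stand-in for neps' BaseOptimizer, just for this sketch
        def __init__(self, search_space):
            self.search_space = search_space

    class MyOptimizer(BaseOptimizer):  # hypothetical user-defined searcher
        def __init__(self, search_space, initial_design_size=10):
            super().__init__(search_space)
            self.initial_design_size = initial_design_size

    searcher_kwargs = {"initial_design_size": 5}  # what neps.run collected
    searcher = MyOptimizer({"lr": (1e-4, 1e-1)}, **searcher_kwargs)
    assert searcher.initial_design_size == 5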
@@ -410,7 +410,7 @@ def _run_args(
        raise TypeError(message) from e

    if isinstance(searcher, (str, Path)) and searcher not in \
-        SearcherConfigs.get_searchers() and searcher is not "default":
+        SearcherConfigs.get_searchers() and searcher != "default":
        # The users has their own custom searcher.
        logging.info("Preparing to run user created searcher")
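The comparison change here is a real bug fix, not style: 'is' tests object identity, so the old check could pass or fail depending on string interning rather than value. A quick demonstration:

    s = "".join(["def", "ault"])  # equal to "default", but a distinct object
    print(s == "default")         # True
    print(s is "default")         # False here; CPython 3.8+ also emits a SyntaxWarning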
@@ -450,15 +450,6 @@ def _run_args(
         # Map the old 'algorithm' argument to 'strategy'
         config['strategy'] = config.pop("algorithm")

-    # Check for deprecated 'algorithm' argument
-    if "algorithm" in config:
-        warnings.warn(
-            "The 'algorithm' argument is deprecated and will be removed in future versions. Please use 'strategy' instead.",
-            DeprecationWarning
-        )
-        # Map the old 'algorithm' argument to 'strategy'
-        config['strategy'] = config.pop("algorithm")
-
     if "strategy" in config:
         searcher_alg = config.pop("strategy")
     else:
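The nine deleted lines duplicated the deprecation check directly above them; the copy was dead weight, since the first config.pop("algorithm") already removed the key before the second check ran. The surviving logic is equivalent to:

    import warnings

    config = {"algorithm": "bayesian_optimization"}  # old-style user config
    if "algorithm" in config:
        warnings.warn(
            "The 'algorithm' argument is deprecated and will be removed in future "
            "versions. Please use 'strategy' instead.",
            DeprecationWarning,
        )
        config["strategy"] = config.pop("algorithm")
    print(config)  # {'strategy': 'bayesian_optimization'}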
22 changes: 14 additions & 8 deletions neps/utils/run_args.py
@@ -61,7 +61,7 @@ def get_run_args_from_yaml(path: str) -> dict:
    settings = {}

    # List allowed NEPS run arguments with simple types (e.g., string, int). Parameters
-   # like 'searcher_kwargs', 'run_pipeline', 'preload_hooks', 'pipeline_space',
+   # like 'run_pipeline', 'preload_hooks', 'pipeline_space',
    # and 'searcher' are excluded due to needing specialized processing.
    expected_parameters = [
        ROOT_DIRECTORY,
@@ -147,7 +147,7 @@ def config_loader(path: str) -> dict:

 def extract_leaf_keys(d: dict, special_keys: dict | None = None) -> tuple[dict, dict]:
     """Recursive function to extract leaf keys and their values from a nested dictionary.
-    Special keys (e.g. 'searcher_kwargs', 'run_pipeline') are also extracted if present
+    Special keys (e.g. 'run_pipeline') are also extracted if present
     and their corresponding values (dict) at any level in the nested structure.

     :param d: The dictionary to extract values from.
@@ -157,7 +157,6 @@ def extract_leaf_keys(d: dict, special_keys: dict | None = None) -> tuple[dict,
    """
    if special_keys is None:
        special_keys = {
-           SEARCHER_KWARGS: None,
            RUN_PIPELINE: None,
            PRE_LOAD_HOOKS: None,
            SEARCHER: None,
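For orientation, extract_leaf_keys splits the loaded YAML into plain leaf settings and the named special keys; this hunk simply drops SEARCHER_KWARGS from that special list. An illustrative reimplementation of the split (not the neps source):

    def extract_leaf_keys_sketch(d, special_keys=None):
        if special_keys is None:
            special_keys = {"run_pipeline": None, "pre_load_hooks": None,
                            "searcher": None}
        leafs = {}
        for key, value in d.items():
            if key in special_keys:
                special_keys[key] = value  # pulled out for special handling
            elif isinstance(value, dict):
                nested, special_keys = extract_leaf_keys_sketch(value, special_keys)
                leafs.update(nested)       # flatten nested groups to leaves
            else:
                leafs[key] = value
        return leafs, special_keys

    leafs, special = extract_leaf_keys_sketch(
        {"budget": {"max_evaluations_total": 20},
         "searcher": {"strategy": "bayesian_optimization"}}
    )
    print(leafs)    # {'max_evaluations_total': 20}
    print(special)  # {'run_pipeline': None, 'pre_load_hooks': None,
                    #  'searcher': {'strategy': 'bayesian_optimization'}}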
@@ -183,8 +182,8 @@ def handle_special_argument_cases(settings: dict, special_configs: dict) -> None
    This function updates 'settings' with values from 'special_configs'. It handles
    specific keys that require more complex processing, such as 'pipeline_space' and
    'searcher', which may need to load a function/dict from paths. It also manages nested
-   configurations like 'searcher_kwargs' and 'pre_load_hooks' which need individual
-   processing or function loading.
+   configurations like 'pre_load_hooks' which need individual processing or function
+   loading.

    Parameters:
    - settings (dict): The dictionary to be updated with processed configurations.
@@ -269,8 +268,12 @@ def process_searcher(key: str, special_configs: dict, settings: dict):
        # determine if dict contains path_loading or the actual searcher config
        expected_keys = {"path", "name"}
        actual_keys = set(searcher.keys())
-       if expected_keys == actual_keys:
-           searcher = load_and_return_object(searcher["path"], searcher["name"], key)
+       if expected_keys.issubset(actual_keys):
+           path = searcher.pop("path")
+           name = searcher.pop("name")
+           settings[SEARCHER_KWARGS] = searcher
+           searcher = load_and_return_object(path, name, key)
+
    elif isinstance(searcher, (str, Path)):
        pass
    else:
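This is the heart of the commit. Previously a dict-valued searcher had to contain exactly {"path", "name"}; now those two keys are popped off and every remaining key is stashed in settings as searcher kwargs for the optimizer constructor. A standalone sketch (file and class names hypothetical):

    settings = {}
    searcher = {"path": "my_searchers.py", "name": "MySearcher",  # hypothetical
                "initial_design_size": 5}
    expected_keys = {"path", "name"}
    if expected_keys.issubset(set(searcher.keys())):
        path = searcher.pop("path")
        name = searcher.pop("name")
        settings["searcher_kwargs"] = searcher  # leftover keys become kwargs
    print(settings)  # {'searcher_kwargs': {'initial_design_size': 5}}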
@@ -426,7 +429,7 @@ def check_run_args(settings: dict) -> None:
            if not all(callable(item) for item in value):
                raise TypeError("All items in 'pre_load_hooks' must be callable.")
        elif param == SEARCHER:
-           if not (isinstance(param, str) or issubclass(param, BaseOptimizer)):
+           if not (isinstance(param, (str, dict)) or issubclass(param, BaseOptimizer)):
                raise TypeError(
                    "Parameter 'searcher' must be a string or a class that is a subclass "
                    "of BaseOptimizer."
@@ -510,6 +513,9 @@ def check_double_reference(
        if name == RUN_ARGS:
            # Ignoring run_args argument
            continue
+       if name == SEARCHER_KWARGS and func_arguments[name] == {}:
+           continue
+
        if name in yaml_arguments:
            raise ValueError(
                f"Conflict for argument '{name}': Argument is defined both via "
19 changes: 8 additions & 11 deletions tests/test_yaml_run_args/test_yaml_run_args.py
@@ -111,8 +111,8 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]],
            "loss_value_on_error": 4.2,
            "cost_value_on_error": 3.7,
            "ignore_errors": True,
-           "searcher": "bayesian_optimization",
-           "searcher_kwargs": {"initial_design_size": 5, "surrogate_model": "gp"},
+           "searcher": {"strategy": "bayesian_optimization",
+                        "initial_design_size": 5, "surrogate_model": "gp"},
            "pre_load_hooks": [hook1, hook2],
        },
    ),
@@ -133,8 +133,8 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]],
            "loss_value_on_error": 2.4,
            "cost_value_on_error": 2.1,
            "ignore_errors": False,
-           "searcher": "bayesian_optimization",
-           "searcher_kwargs": {"initial_design_size": 5, "surrogate_model": "gp"},
+           "searcher": {"strategy": "bayesian_optimization",
+                        "initial_design_size": 5, "surrogate_model": "gp"},
            "pre_load_hooks": [hook1],
        },
    ),
@@ -147,8 +147,8 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]],
            "overwrite_working_directory": True,
            "post_run_summary": False,
            "continue_until_max_evaluation_completed": False,
-           "searcher": "bayesian_optimization",
-           "searcher_kwargs": {"initial_design_size": 5, "surrogate_model": "gp"},
+           "searcher": {"strategy": "bayesian_optimization",
+                        "initial_design_size": 5, "surrogate_model": "gp"},
        },
    ),
    (
@@ -178,11 +178,8 @@ def are_functions_equivalent(f1: Union[Callable, List[Callable]],
            "loss_value_on_error": 2.4,
            "cost_value_on_error": 2.1,
            "ignore_errors": False,
-           "searcher": BayesianOptimization,
-           "searcher_kwargs": {
-               "initial_design_size": 5,
-               "surrogate_model": "gp"
-           },
+           "searcher": {"strategy": "bayesian_optimization", "initial_design_size": 5,
+                        "surrogate_model": "gp"},
            "pre_load_hooks": [hook1]
        })
