diff --git a/neps/utils/run_args.py b/neps/utils/run_args.py index 46c12414..71870c98 100644 --- a/neps/utils/run_args.py +++ b/neps/utils/run_args.py @@ -9,7 +9,7 @@ import logging import sys from pathlib import Path -from typing import Any, Callable +from typing import Callable import yaml @@ -45,13 +45,13 @@ def get_run_args_from_yaml(path: str) -> dict: """Load and validate NEPS run arguments from a specified YAML configuration file provided via run_args. - This function reads a YAML file, extracts the arguments required by NEPS, + This function reads a YAML file, extracts the arguments required by NePS, validates these arguments, and then returns them in a dictionary. It checks for the presence and validity of expected parameters, and distinctively handles more complex configurations, specifically those that are dictionaries(e.g. pipeline_space) or objects(e.g. run_pipeline) requiring loading. - Parameters: + Args: path (str): The file path to the YAML configuration file. Returns: @@ -66,7 +66,7 @@ def get_run_args_from_yaml(path: str) -> dict: # Initialize an empty dictionary to hold the extracted settings settings = {} - # List allowed NEPS run arguments with simple types (e.g., string, int). Parameters + # List allowed NePS run arguments with simple types (e.g., string, int). Parameters # like 'run_pipeline', 'preload_hooks', 'pipeline_space', # and 'searcher' are excluded due to needing specialized processing. expected_parameters = [ @@ -119,22 +119,15 @@ def get_run_args_from_yaml(path: str) -> dict: def config_loader(path: str) -> dict: """Loads a YAML file and returns the contents under the 'run_args' key. - Validates the existence and format of the YAML file and checks for the presence of - the 'run_args' as the only top level key. If any conditions are not met, - raises an - exception with a helpful message. - Args: path (str): Path to the YAML file. Returns: - dict: Contents under the 'run_args' key. + Content of the yaml (dict) Raises: FileNotFoundError: If the file at 'path' does not exist. ValueError: If the file is not a valid YAML. - KeyError: If 'run_args' key is missing. - KeyError: If 'run_args' is not the only top level key """ try: with open(path) as file: # noqa: PTH123 @@ -156,10 +149,13 @@ def extract_leaf_keys(d: dict, special_keys: dict | None = None) -> tuple[dict, Special keys (e.g.'run_pipeline') are also extracted if present and their corresponding values (dict) at any level in the nested structure. - :param d: The dictionary to extract values from. - :param special_keys: A dictionary to store values of special keys. - :return: A tuple containing the leaf keys dictionary and the dictionary for - special keys. + Args: + d (dict): The dictionary to extract values from. + special_keys (dict|None): A dictionary to store values of special keys. + + Returns: + A tuple containing the leaf keys dictionary and the dictionary for + special keys. """ if special_keys is None: special_keys = { @@ -191,14 +187,11 @@ def handle_special_argument_cases(settings: dict, special_configs: dict) -> None configurations like 'pre_load_hooks' which need individual processing or function loading. - Parameters: - - settings (dict): The dictionary to be updated with processed configurations. - - special_configs (dict): A dictionary containing configuration keys and values + Args: + settings (dict): The dictionary to be updated with processed configurations. + special_configs (dict): A dictionary containing configuration keys and values that require special processing. - Returns: - - None: The function modifies 'settings' in place. - """ # process special configs process_run_pipeline(RUN_PIPELINE, special_configs, settings) @@ -261,13 +254,13 @@ def process_searcher(key: str, special_configs: dict, settings: dict) -> None: Checks if the key exists in special_configs. If found, it processes the value based on its type. Updates settings with the processed searcher. - Parameters: - key (str): Key to look up in special_configs. - special_configs (dict): Dictionary of special configurations. - settings (dict): Dictionary to update with the processed searcher. + Args: + key (str): Key to look up in special_configs. + special_configs (dict): Dictionary of special configurations. + settings (dict): Dictionary to update with the processed searcher. Raises: - TypeError: If the value for the key is neither a string, Path, nor a dictionary. + TypeError: If the value for the key is neither a string, Path, nor a dictionary. """ if special_configs.get(key) is not None: searcher = special_configs[key] @@ -294,13 +287,13 @@ def process_searcher(key: str, special_configs: dict, settings: dict) -> None: def process_run_pipeline(key: str, special_configs: dict, settings: dict) -> None: """Processes the run pipeline configuration and updates the settings dictionary. - Parameters: - key (str): Key to look up in special_configs. - special_configs (dict): Dictionary of special configurations. - settings (dict): Dictionary to update with the processed function. + Args: + key (str): Key to look up in special_configs. + special_configs (dict): Dictionary of special configurations. + settings (dict): Dictionary to update with the processed function. Raises: - KeyError: If required keys ('path' and 'name') are missing in the config. + KeyError: If required keys ('path' and 'name') are missing in the config. """ if special_configs.get(key) is not None: config = special_configs[key] @@ -403,7 +396,7 @@ def check_run_args(settings: dict) -> None: TypeError for type mismatches. Args: - settings (dict): NEPS configuration settings. + settings (dict): NePS configuration settings. Raises: TypeError: For mismatched setting value types. @@ -462,14 +455,14 @@ def check_essential_arguments( searcher: BaseOptimizer | None, run_args: str | None, ) -> None: - """Validates essential NEPS configuration arguments. + """Validates essential NePS configuration arguments. Ensures 'run_pipeline', 'root_directory', 'pipeline_space', and either 'max_cost_total' or 'max_evaluation_total' are provided for NePS execution. Raises ValueError with missing argument details. Additionally, checks 'searcher' is a BaseOptimizer if 'pipeline_space' is absent. - Parameters: + Args: run_pipeline: Function for the pipeline execution. root_directory (str): Directory path for data storage. pipeline_space: search space for this run. @@ -499,19 +492,19 @@ def check_essential_arguments( def check_double_reference( func: Callable, func_arguments: dict, yaml_arguments: dict -) -> Any: +) -> None: """Checks if no argument is defined both via function arguments and YAML. - Parameters: - - func (Callable): The function to check arguments against. - - func_arguments (Dict): A dictionary containing the provided arguments to the - function and their values. - - yaml_arguments (Dict): A dictionary containing the arguments provided via a YAML - file. + Args: + func (Callable): The function to check arguments against. + func_arguments (Dict): A dictionary containing the provided arguments to the + function and their values. + yaml_arguments (Dict): A dictionary containing the arguments provided via a YAML + file. Raises: - - ValueError: If any provided argument is defined both via function arguments and the - YAML file. + ValueError: If any provided argument is defined both via function arguments and + the YAML file. """ sig = inspect.signature(func)