Skip to content

Commit

Permalink
clean up code and docsctrings
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Jun 18, 2024
1 parent 736b58a commit 285b993
Showing 1 changed file with 37 additions and 44 deletions.
81 changes: 37 additions & 44 deletions neps/utils/run_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import sys
from pathlib import Path
from typing import Any, Callable
from typing import Callable

import yaml

Expand Down Expand Up @@ -45,13 +45,13 @@ def get_run_args_from_yaml(path: str) -> dict:
"""Load and validate NEPS run arguments from a specified YAML configuration file
provided via run_args.
This function reads a YAML file, extracts the arguments required by NEPS,
This function reads a YAML file, extracts the arguments required by NePS,
validates these arguments, and then returns them in a dictionary. It checks for the
presence and validity of expected parameters, and distinctively handles more complex
configurations, specifically those that are dictionaries(e.g. pipeline_space) or
objects(e.g. run_pipeline) requiring loading.
Parameters:
Args:
path (str): The file path to the YAML configuration file.
Returns:
Expand All @@ -66,7 +66,7 @@ def get_run_args_from_yaml(path: str) -> dict:
# Initialize an empty dictionary to hold the extracted settings
settings = {}

# List allowed NEPS run arguments with simple types (e.g., string, int). Parameters
# List allowed NePS run arguments with simple types (e.g., string, int). Parameters
# like 'run_pipeline', 'preload_hooks', 'pipeline_space',
# and 'searcher' are excluded due to needing specialized processing.
expected_parameters = [
Expand Down Expand Up @@ -119,22 +119,15 @@ def get_run_args_from_yaml(path: str) -> dict:
def config_loader(path: str) -> dict:
"""Loads a YAML file and returns the contents under the 'run_args' key.
Validates the existence and format of the YAML file and checks for the presence of
the 'run_args' as the only top level key. If any conditions are not met,
raises an
exception with a helpful message.
Args:
path (str): Path to the YAML file.
Returns:
dict: Contents under the 'run_args' key.
Content of the yaml (dict)
Raises:
FileNotFoundError: If the file at 'path' does not exist.
ValueError: If the file is not a valid YAML.
KeyError: If 'run_args' key is missing.
KeyError: If 'run_args' is not the only top level key
"""
try:
with open(path) as file: # noqa: PTH123
Expand All @@ -156,10 +149,13 @@ def extract_leaf_keys(d: dict, special_keys: dict | None = None) -> tuple[dict,
Special keys (e.g.'run_pipeline') are also extracted if present
and their corresponding values (dict) at any level in the nested structure.
:param d: The dictionary to extract values from.
:param special_keys: A dictionary to store values of special keys.
:return: A tuple containing the leaf keys dictionary and the dictionary for
special keys.
Args:
d (dict): The dictionary to extract values from.
special_keys (dict|None): A dictionary to store values of special keys.
Returns:
A tuple containing the leaf keys dictionary and the dictionary for
special keys.
"""
if special_keys is None:
special_keys = {
Expand Down Expand Up @@ -191,14 +187,11 @@ def handle_special_argument_cases(settings: dict, special_configs: dict) -> None
configurations like 'pre_load_hooks' which need individual processing or function
loading.
Parameters:
- settings (dict): The dictionary to be updated with processed configurations.
- special_configs (dict): A dictionary containing configuration keys and values
Args:
settings (dict): The dictionary to be updated with processed configurations.
special_configs (dict): A dictionary containing configuration keys and values
that require special processing.
Returns:
- None: The function modifies 'settings' in place.
"""
# process special configs
process_run_pipeline(RUN_PIPELINE, special_configs, settings)
Expand Down Expand Up @@ -261,13 +254,13 @@ def process_searcher(key: str, special_configs: dict, settings: dict) -> None:
Checks if the key exists in special_configs. If found, it processes the
value based on its type. Updates settings with the processed searcher.
Parameters:
key (str): Key to look up in special_configs.
special_configs (dict): Dictionary of special configurations.
settings (dict): Dictionary to update with the processed searcher.
Args:
key (str): Key to look up in special_configs.
special_configs (dict): Dictionary of special configurations.
settings (dict): Dictionary to update with the processed searcher.
Raises:
TypeError: If the value for the key is neither a string, Path, nor a dictionary.
TypeError: If the value for the key is neither a string, Path, nor a dictionary.
"""
if special_configs.get(key) is not None:
searcher = special_configs[key]
Expand All @@ -294,13 +287,13 @@ def process_searcher(key: str, special_configs: dict, settings: dict) -> None:
def process_run_pipeline(key: str, special_configs: dict, settings: dict) -> None:
"""Processes the run pipeline configuration and updates the settings dictionary.
Parameters:
key (str): Key to look up in special_configs.
special_configs (dict): Dictionary of special configurations.
settings (dict): Dictionary to update with the processed function.
Args:
key (str): Key to look up in special_configs.
special_configs (dict): Dictionary of special configurations.
settings (dict): Dictionary to update with the processed function.
Raises:
KeyError: If required keys ('path' and 'name') are missing in the config.
KeyError: If required keys ('path' and 'name') are missing in the config.
"""
if special_configs.get(key) is not None:
config = special_configs[key]
Expand Down Expand Up @@ -403,7 +396,7 @@ def check_run_args(settings: dict) -> None:
TypeError for type mismatches.
Args:
settings (dict): NEPS configuration settings.
settings (dict): NePS configuration settings.
Raises:
TypeError: For mismatched setting value types.
Expand Down Expand Up @@ -462,14 +455,14 @@ def check_essential_arguments(
searcher: BaseOptimizer | None,
run_args: str | None,
) -> None:
"""Validates essential NEPS configuration arguments.
"""Validates essential NePS configuration arguments.
Ensures 'run_pipeline', 'root_directory', 'pipeline_space', and either
'max_cost_total' or 'max_evaluation_total' are provided for NePS execution.
Raises ValueError with missing argument details. Additionally, checks 'searcher'
is a BaseOptimizer if 'pipeline_space' is absent.
Parameters:
Args:
run_pipeline: Function for the pipeline execution.
root_directory (str): Directory path for data storage.
pipeline_space: search space for this run.
Expand Down Expand Up @@ -499,19 +492,19 @@ def check_essential_arguments(

def check_double_reference(
func: Callable, func_arguments: dict, yaml_arguments: dict
) -> Any:
) -> None:
"""Checks if no argument is defined both via function arguments and YAML.
Parameters:
- func (Callable): The function to check arguments against.
- func_arguments (Dict): A dictionary containing the provided arguments to the
function and their values.
- yaml_arguments (Dict): A dictionary containing the arguments provided via a YAML
file.
Args:
func (Callable): The function to check arguments against.
func_arguments (Dict): A dictionary containing the provided arguments to the
function and their values.
yaml_arguments (Dict): A dictionary containing the arguments provided via a YAML
file.
Raises:
- ValueError: If any provided argument is defined both via function arguments and the
YAML file.
ValueError: If any provided argument is defined both via function arguments and
the YAML file.
"""
sig = inspect.signature(func)

Expand Down

0 comments on commit 285b993

Please sign in to comment.