Skip to content

Commit

Permalink
refactor: modularize file state (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman authored Aug 2, 2024
1 parent 2ee606f commit 08f30ae
Show file tree
Hide file tree
Showing 63 changed files with 5,523 additions and 1,696 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ jahs_bench_data/

# MacOS
*.DS_Store

# Yaml tests
path
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
files: '^src/.*\.py$'

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
rev: v1.11.1
hooks:
- id: mypy
files: |
Expand All @@ -42,7 +42,7 @@ repos:
- "--show-traceback"

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.28.2
rev: 0.29.1
hooks:
- id: check-github-workflows
files: '^github/workflows/.*\.ya?ml$'
Expand All @@ -51,7 +51,7 @@ repos:
files: '^\.github/dependabot\.ya?ml$'

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.2
rev: v0.5.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
Expand Down
64 changes: 34 additions & 30 deletions neps/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""API for the neps package.
"""
"""API for the neps package."""

from __future__ import annotations

import inspect
Expand All @@ -12,7 +12,7 @@
from neps.utils.run_args import Settings, Default

from neps.utils.common import instance_from_map
from neps.runtime import launch_runtime
from neps.runtime import _launch_runtime
from neps.optimizers import BaseOptimizer, SearcherMapping
from neps.search_spaces.parameter import Parameter
from neps.search_spaces.search_space import (
Expand All @@ -24,6 +24,8 @@
from neps.utils.common import get_searcher_data, get_value
from neps.optimizers.info import SearcherConfigs

logger = logging.getLogger(__name__)


def run(
run_pipeline: Callable | None = Default(None),
Expand Down Expand Up @@ -59,7 +61,8 @@ def run(
"asha",
"regularized_evolution",
]
| BaseOptimizer | Path
| BaseOptimizer
| Path
) = Default("default"),
**searcher_kwargs,
) -> None:
Expand Down Expand Up @@ -144,13 +147,11 @@ def run(
)
max_cost_total = searcher_kwargs["budget"]
del searcher_kwargs["budget"]

settings = Settings(locals(), run_args)
# TODO: check_essentials,

logger = logging.getLogger("neps")

# DO NOT use any neps arguments directly; instead, access them via the Settings class.

if settings.pre_load_hooks is None:
settings.pre_load_hooks = []

Expand All @@ -175,8 +176,9 @@ def run(
# TODO habe hier searcher kwargs gedroppt, sprich das merging muss davor statt
# finden
searcher_info["searcher_args"] = settings.searcher_kwargs
settings.searcher = settings.searcher(search_space,
**settings.searcher_kwargs)
settings.searcher = settings.searcher(
search_space, **settings.searcher_kwargs
)
else:
# Raise an error if searcher is not a subclass of BaseOptimizer
raise TypeError(
Expand All @@ -200,7 +202,6 @@ def run(
ignore_errors=settings.ignore_errors,
loss_value_on_error=settings.loss_value_on_error,
cost_value_on_error=settings.cost_value_on_error,
logger=logger,
searcher=settings.searcher,
**settings.searcher_kwargs,
)
Expand All @@ -220,23 +221,25 @@ def run(
)

if settings.task_id is not None:
settings.root_directory = Path(settings.root_directory) / (f"task_"
f"{settings.task_id}")
settings.root_directory = Path(settings.root_directory) / (
f"task_" f"{settings.task_id}"
)
if settings.development_stage_id is not None:
settings.root_directory = (Path(settings.root_directory) /
f"dev_{settings.development_stage_id}")
settings.root_directory = (
Path(settings.root_directory) / f"dev_{settings.development_stage_id}"
)

launch_runtime(
_launch_runtime(
evaluation_fn=settings.run_pipeline,
sampler=searcher_instance,
optimizer=searcher_instance,
optimizer_info=searcher_info,
optimization_dir=settings.root_directory,
max_cost_total=settings.max_cost_total,
optimization_dir=Path(settings.root_directory),
max_evaluations_total=settings.max_evaluations_total,
max_evaluations_per_run=settings.max_evaluations_per_run,
continue_until_max_evaluation_completed
=settings.continue_until_max_evaluation_completed,
logger=logger,
max_evaluations_for_worker=settings.max_evaluations_per_run,
continue_until_max_evaluation_completed=settings.continue_until_max_evaluation_completed,
loss_value_on_error=settings.loss_value_on_error,
cost_value_on_error=settings.cost_value_on_error,
ignore_errors=settings.ignore_errors,
overwrite_optimization_dir=settings.overwrite_working_directory,
pre_load_hooks=settings.pre_load_hooks,
Expand All @@ -260,7 +263,6 @@ def _run_args(
ignore_errors: bool = False,
loss_value_on_error: None | float = None,
cost_value_on_error: None | float = None,
logger=None,
searcher: (
Literal[
"default",
Expand Down Expand Up @@ -306,13 +308,17 @@ def _run_args(
raise TypeError(message) from e

# Load the information of the optimizer
if isinstance(searcher, (str, Path)) and searcher not in \
SearcherConfigs.get_searchers() and searcher != "default":
if (
isinstance(searcher, (str, Path))
and searcher not in SearcherConfigs.get_searchers()
and searcher != "default"
):
# The users have their own custom searcher provided via yaml.
logging.info("Preparing to run user created searcher")

searcher_config, file_name = get_searcher_data(searcher,
loading_custom_searcher=True)
searcher_config, file_name = get_searcher_data(
searcher, loading_custom_searcher=True
)
# name defined via key or the filename of the yaml
searcher_name = searcher_config.pop("name", file_name)
searcher_info["searcher_selection"] = "user-yaml"
Expand Down Expand Up @@ -351,21 +357,19 @@ def _run_args(
warnings.warn(
"The 'algorithm' argument is deprecated and will be removed in "
"future versions. Please use 'strategy' instead.",
DeprecationWarning
DeprecationWarning,
)
# Map the old 'algorithm' argument to 'strategy'
searcher_config['strategy'] = searcher_config.pop("algorithm")
searcher_config["strategy"] = searcher_config.pop("algorithm")

if "strategy" in searcher_config:
searcher_alg = searcher_config.pop("strategy")
else:
raise KeyError(f"Missing key strategy in searcher config:{searcher_config}")


logger.info(f"Running {searcher_name} as the searcher")
logger.info(f"Strategy: {searcher_alg}")


# Used to create the yaml holding information about the searcher.
# Also important for testing and debugging the api.
searcher_info["searcher_name"] = searcher_name
Expand Down
89 changes: 89 additions & 0 deletions neps/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Environment variable parsing for the state."""

from __future__ import annotations

import os
from typing import Callable, TypeVar

T = TypeVar("T")
V = TypeVar("V")


def get_env(key: str, parse: Callable[[str], T], default: V) -> T | V:
"""Get an environment variable or return a default value."""
if (e := os.environ.get(key)) is not None:
return parse(e)

return default


def is_nullable(e: str) -> bool:
"""Check if an environment variable is nullable."""
return e.lower() in ("none", "n", "null")


TRIAL_FILELOCK_POLL = get_env(
"NEPS_TRIAL_FILELOCK_POLL",
parse=float,
default=0.05,
)
TRIAL_FILELOCK_TIMEOUT = get_env(
"NEPS_TRIAL_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)

JOBQUEUE_FILELOCK_POLL = get_env(
"NEPS_JOBQUEUE_FILELOCK_POLL",
parse=float,
default=0.05,
)
JOBQUEUE_FILELOCK_TIMEOUT = get_env(
"NEPS_JOBQUEUE_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)

SEED_SNAPSHOT_FILELOCK_POLL = get_env(
"NEPS_SEED_SNAPSHOT_FILELOCK_POLL",
parse=float,
default=0.05,
)
SEED_SNAPSHOT_FILELOCK_TIMEOUT = get_env(
"NEPS_SEED_SNAPSHOT_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)

OPTIMIZER_INFO_FILELOCK_POLL = get_env(
"NEPS_OPTIMIZER_INFO_FILELOCK_POLL",
parse=float,
default=0.05,
)
OPTIMIZER_INFO_FILELOCK_TIMEOUT = get_env(
"NEPS_OPTIMIZER_INFO_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)

OPTIMIZER_STATE_FILELOCK_POLL = get_env(
"NEPS_OPTIMIZER_STATE_FILELOCK_POLL",
parse=float,
default=0.05,
)
OPTIMIZER_STATE_FILELOCK_TIMEOUT = get_env(
"NEPS_OPTIMIZER_STATE_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)

GLOBAL_ERR_FILELOCK_POLL = get_env(
"NEPS_GLOBAL_ERR_FILELOCK_POLL",
parse=float,
default=0.05,
)
GLOBAL_ERR_FILELOCK_TIMEOUT = get_env(
"NEPS_GLOBAL_ERR_FILELOCK_TIMEOUT",
parse=lambda e: None if is_nullable(e) else float(e),
default=None,
)
47 changes: 47 additions & 0 deletions neps/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Exceptions for NePS that don't belong in a specific module."""

from __future__ import annotations


class NePSError(Exception):
"""Base class for all NePS exceptions.
This allows an easier way to catch all NePS exceptions
if we inherit all exceptions from this class.
"""


class VersionMismatchError(NePSError):
"""Raised when the version of a resource does not match the expected version."""


class VersionedResourceAlreadyExistsError(NePSError):
"""Raised when a version already exists when trying to create a new versioned
data.
"""


class VersionedResourceRemovedError(NePSError):
"""Raised when a version already exists when trying to create a new versioned
data.
"""


class VersionedResourceDoesNotExistsError(NePSError):
"""Raised when a versioned resource does not exist at a location."""


class LockFailedError(NePSError):
"""Raised when a lock cannot be acquired."""


class TrialAlreadyExistsError(VersionedResourceAlreadyExistsError):
"""Raised when a trial already exists in the store."""


class TrialNotFoundError(VersionedResourceDoesNotExistsError):
"""Raised when a trial already exists in the store."""


class WorkerFailedToGetPendingTrialsError(NePSError):
"""Raised when a worker failed to get pending trials."""
5 changes: 3 additions & 2 deletions neps/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Callable
from typing import Callable, Mapping

from .base_optimizer import BaseOptimizer
from .bayesian_optimization.cost_cooling import CostCooling
Expand All @@ -26,7 +26,8 @@
from .random_search.optimizer import RandomSearch
from .regularized_evolution.optimizer import RegularizedEvolution

SearcherMapping: dict[str, Callable] = {
# TODO: Rename Searcher to Optimizer...
SearcherMapping: Mapping[str, Callable[..., BaseOptimizer]] = {
"bayesian_optimization": BayesianOptimization,
"pibo": partial(BayesianOptimization, disable_priors=False),
"cost_cooling_bayesian_optimization": CostCooling,
Expand Down
Loading

0 comments on commit 08f30ae

Please sign in to comment.