fix double logging and move post_eval_hook to runtime
DaStoll committed Jun 29, 2024
1 parent 3a28758 commit 7397f4c
Showing 3 changed files with 94 additions and 111 deletions.
78 changes: 2 additions & 76 deletions neps/api.py
@@ -16,7 +16,6 @@
from neps.utils.common import instance_from_map
from neps.runtime import launch_runtime
from neps.optimizers import BaseOptimizer, SearcherMapping
from neps.plot.tensorboard_eval import tblogger
from neps.search_spaces.parameter import Parameter
from neps.search_spaces.search_space import (
SearchSpace,
@@ -25,81 +24,9 @@
)
from neps.status.status import post_run_csv
from neps.utils.common import get_searcher_data, get_value
from neps.utils.data_loading import _get_loss
from neps.optimizers.info import SearcherConfigs


def _post_evaluation_hook_function(
_loss_value_on_error: None | float, _ignore_errors: bool
):
def _post_evaluation_hook(
config,
config_id,
config_working_directory,
result,
logger,
loss_value_on_error=_loss_value_on_error,
ignore_errors=_ignore_errors,
):
working_directory = Path(config_working_directory, "../../")
loss = _get_loss(result, loss_value_on_error, ignore_errors=ignore_errors)

# 1. Write all configs and losses
all_configs_losses = Path(working_directory, "all_losses_and_configs.txt")

def write_loss_and_config(file_handle, loss_, config_id_, config_):
file_handle.write(f"Loss: {loss_}\n")
file_handle.write(f"Config ID: {config_id_}\n")
file_handle.write(f"Config: {config_}\n")
file_handle.write(79 * "-" + "\n")

with all_configs_losses.open("a", encoding="utf-8") as f:
write_loss_and_config(f, loss, config_id, config)

# no need to handle best loss cases if an error occurred
if result == "error":
return

# The "best" loss exists only in the pareto sense for multi-objective
is_multi_objective = isinstance(loss, dict)
if is_multi_objective:
logger.info(f"Finished evaluating config {config_id}")
return

# 2. Write best losses/configs
best_loss_trajectory_file = Path(working_directory, "best_loss_trajectory.txt")
best_loss_config_trajectory_file = Path(
working_directory, "best_loss_with_config_trajectory.txt"
)

if not best_loss_trajectory_file.exists():
is_new_best = result != "error"
else:
best_loss_trajectory = best_loss_trajectory_file.read_text(encoding="utf-8")
best_loss_trajectory = list(best_loss_trajectory.rstrip("\n").split("\n"))
best_loss = best_loss_trajectory[-1]
is_new_best = float(best_loss) > loss

if is_new_best:
with best_loss_trajectory_file.open("a", encoding="utf-8") as f:
f.write(f"{loss}\n")

with best_loss_config_trajectory_file.open("a", encoding="utf-8") as f:
write_loss_and_config(f, loss, config_id, config)

logger.info(
f"Finished evaluating config {config_id}"
f" -- new best with loss {float(loss) :.6f}"
)

else:
logger.info(f"Finished evaluating config {config_id}")

tblogger.end_of_config()

return _post_evaluation_hook


def run(
run_pipeline: Callable | None = None,
root_directory: str | Path | None = None,
@@ -342,9 +269,8 @@ def run(
max_evaluations_per_run=max_evaluations_per_run,
continue_until_max_evaluation_completed=continue_until_max_evaluation_completed,
logger=logger,
post_evaluation_hook=_post_evaluation_hook_function(
loss_value_on_error, ignore_errors
),
loss_value_on_error=loss_value_on_error,
ignore_errors=ignore_errors,
overwrite_optimization_dir=overwrite_working_directory,
pre_load_hooks=pre_load_hooks,
)
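For context on the calling side: error handling is now configured directly on `neps.run` and forwarded to `launch_runtime`, which builds the post-evaluation hook internally instead of `neps.api` wrapping the options into a closure. A minimal usage sketch, assuming the `pipeline_space` argument and the `FloatParameter` class from the library's public API of this era (neither appears in this diff):

```python
import neps


def run_pipeline(learning_rate: float) -> dict:
    # Toy objective for illustration only: smaller learning rates score better.
    return {"loss": learning_rate}


pipeline_space = {
    # Assumed parameter API; adjust to the neps version in use.
    "learning_rate": neps.FloatParameter(lower=1e-5, upper=1e-1, log=True),
}

neps.run(
    run_pipeline=run_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/example",
    max_evaluations_total=5,
    # Forwarded to launch_runtime, which now owns the post-evaluation hook.
    loss_value_on_error=10.0,
    ignore_errors=True,
)
```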
112 changes: 91 additions & 21 deletions neps/runtime.py
@@ -73,7 +73,6 @@
from neps.utils.files import deserialize, empty_file, serialize
from neps.utils.types import (
ERROR,
POST_EVAL_HOOK_SIGNATURE,
ConfigID,
ConfigResult,
RawConfig,
@@ -752,6 +751,75 @@ def _sample_trial_from_optimizer(
)


def _post_evaluation_hook( # type: ignore
trial: Trial,
result: ERROR | dict[str, Any],
logger: logging.Logger,
loss_value_on_error: float | None,
ignore_errors,
) -> None:
# We import here to avoid circular imports
from neps.plot.tensorboard_eval import tblogger
from neps.utils.data_loading import _get_loss

working_directory = Path(trial.pipeline_dir, "../../")
loss = _get_loss(result, loss_value_on_error, ignore_errors=ignore_errors)

# 1. Write all configs and losses
all_configs_losses = Path(working_directory, "all_losses_and_configs.txt")

def write_loss_and_config(file_handle, loss_, config_id_, config_): # type: ignore
file_handle.write(f"Loss: {loss_}\n")
file_handle.write(f"Config ID: {config_id_}\n")
file_handle.write(f"Config: {config_}\n")
file_handle.write(79 * "-" + "\n")

with all_configs_losses.open("a", encoding="utf-8") as f:
write_loss_and_config(f, loss, trial.id, trial.config)

# no need to handle best loss cases if an error occurred
if result == "error":
return

# The "best" loss exists only in the pareto sense for multi-objective
is_multi_objective = isinstance(loss, dict)
if is_multi_objective:
logger.info(f"Finished evaluating config {trial.id}")
return

# 2. Write best losses/configs
best_loss_trajectory_file = Path(working_directory, "best_loss_trajectory.txt")
best_loss_config_trajectory_file = Path(
working_directory, "best_loss_with_config_trajectory.txt"
)

if not best_loss_trajectory_file.exists():
is_new_best = result != "error"
else:
best_loss_trajectory: str | list[str]
best_loss_trajectory = best_loss_trajectory_file.read_text(encoding="utf-8")
best_loss_trajectory = list(best_loss_trajectory.rstrip("\n").split("\n"))
best_loss = best_loss_trajectory[-1]
is_new_best = float(best_loss) > loss # type: ignore

if is_new_best:
with best_loss_trajectory_file.open("a", encoding="utf-8") as f:
f.write(f"{loss}\n")

with best_loss_config_trajectory_file.open("a", encoding="utf-8") as f:
write_loss_and_config(f, loss, trial.id, trial.config)

logger.info(
f"Finished evaluating config {trial.id}"
f" -- new best with loss {float(loss) :.6f}"
)

else:
logger.info(f"Finished evaluating config {trial.id}")

tblogger.end_of_config()


def launch_runtime( # noqa: PLR0913, C901, PLR0915
*,
evaluation_fn: Callable[..., float | Mapping[str, Any]],
@@ -762,7 +830,8 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915
max_evaluations_per_run: int | None = None,
continue_until_max_evaluation_completed: bool = False,
logger: logging.Logger | None = None,
post_evaluation_hook: POST_EVAL_HOOK_SIGNATURE | None = None,
ignore_errors: bool = False,
loss_value_on_error: None | float = None,
overwrite_optimization_dir: bool = False,
pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None = None,
) -> None:
@@ -781,7 +850,10 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915
continue_until_max_evaluation_completed: Whether to continue until the maximum
evaluations are completed.
logger: The logger to use.
post_evaluation_hook: A hook to run after the evaluation.
        loss_value_on_error: Setting this and cost_value_on_error to any float will
            suppress any error and use the given loss value instead. default: None
ignore_errors: Ignore hyperparameter settings that threw an error and do not raise
an error. Error configs still count towards max_evaluations_total.
overwrite_optimization_dir: Whether to overwrite the optimization directory.
pre_load_hooks: Hooks to run before loading the results.
"""
@@ -950,23 +1022,21 @@ def launch_runtime( # noqa: PLR0913, C901, PLR0915
sampler.used_budget += eval_cost

_result: ERROR | dict[str, Any]
if post_evaluation_hook is not None:
report = trial.report
if isinstance(report, ErrorReport):
_result = "error"
elif isinstance(report, SuccessReport):
_result = dict(report.results)
else:
_type = type(report)
raise TypeError(f"Unknown result type '{_type}' for report: {report}")

post_evaluation_hook(
trial.config,
trial.id,
trial.pipeline_dir,
_result,
logger,
)
report = trial.report
if isinstance(report, ErrorReport):
_result = "error"
elif isinstance(report, SuccessReport):
_result = dict(report.results)
else:
_type = type(report)
raise TypeError(f"Unknown result type '{_type}' for report: {report}")

_post_evaluation_hook(
trial,
_result,
logger,
loss_value_on_error,
ignore_errors,
)

evaluations_in_this_run += 1
logger.info(f"Finished evaluating config {trial.id}")
15 changes: 1 addition & 14 deletions neps/utils/types.py
@@ -2,10 +2,8 @@

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Union
from typing_extensions import TypeAlias

import numpy as np
@@ -39,17 +37,6 @@ def __repr__(self) -> str:
NotSet = _NotSet()


POST_EVAL_HOOK_SIGNATURE: TypeAlias = Callable[
[
Mapping[str, Any],
str,
Path,
Union[Dict[str, Any], ERROR],
logging.Logger,
],
None,
]

f64 = np.float64
i64 = np.int64

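Finally, for readers inspecting the on-disk artefacts: every evaluation still appends a fixed four-line record to `all_losses_and_configs.txt` via `write_loss_and_config`, which now lives in `neps/runtime.py`. A self-contained sketch of that record layout (the helper name and values below are illustrative, not part of the library):

```python
from pathlib import Path


def append_record(path: Path, loss: float, config_id: str, config: dict) -> None:
    # Same four-line record layout as write_loss_and_config in the relocated hook.
    with path.open("a", encoding="utf-8") as f:
        f.write(f"Loss: {loss}\n")
        f.write(f"Config ID: {config_id}\n")
        f.write(f"Config: {config}\n")
        f.write(79 * "-" + "\n")


append_record(
    Path("all_losses_and_configs.txt"),
    loss=0.42,  # illustrative values only
    config_id="3",
    config={"learning_rate": 0.01},
)
```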
