Skip to content

Commit

Permalink
Merge pull request #46 from automl/min_max
Browse files Browse the repository at this point in the history
Add error messages and unit tests for setups where min_budget >= max_budget
  • Loading branch information
Neeratyoy committed Aug 1, 2023
2 parents f8130b3 + 0c6b26a commit adac335
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
34 changes: 23 additions & 11 deletions src/dehb/optimizers/dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
crossover_prob=None, strategy=None, min_budget=None,
max_budget=None, eta=None, min_clip=None, max_clip=None,
boundary_fix_type='random', max_age=np.inf, **kwargs):
# Miscellaneous
self._setup_logger(kwargs)

# Benchmark related variables
self.cs = cs
self.configspace = True if isinstance(self.cs, ConfigSpace.ConfigurationSpace) else False
Expand Down Expand Up @@ -59,7 +62,14 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
# Hyperband related variables
self.min_budget = min_budget
self.max_budget = max_budget
assert self.max_budget > self.min_budget, "only (Max Budget > Min Budget) supported!"
if self.max_budget <= self.min_budget:
self.logger.error("Only (Max Budget > Min Budget) is supported for DEHB.")
if self.max_budget == self.min_budget:
self.logger.error(
"If you have a fixed fidelity, " \
"you can instead run DE. For more information checkout: " \
"https://automl.github.io/DEHB/references/de")
raise AssertionError()
self.eta = eta
self.min_clip = min_clip
self.max_clip = max_clip
Expand All @@ -75,7 +85,18 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
-np.linspace(start=self.max_SH_iter - 1,
stop=0, num=self.max_SH_iter))

# Miscellaneous
# Updating DE parameter list
self.de_params.update({"output_path": self.output_path})

# Global trackers
self.population = None
self.fitness = None
self.inc_score = np.inf
self.inc_config = None
self.history = []

def _setup_logger(self, kwargs):
"""Sets up the logger."""
self.output_path = kwargs['output_path'] if 'output_path' in kwargs else './'
os.makedirs(self.output_path, exist_ok=True)
self.logger = logger
Expand All @@ -86,15 +107,6 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
**_logger_props
)
self.log_filename = "{}/dehb_{}.log".format(self.output_path, log_suffix)
# Updating DE parameter list
self.de_params.update({"output_path": self.output_path})

# Global trackers
self.population = None
self.fitness = None
self.inc_score = np.inf
self.inc_config = None
self.history = []

def reset(self):
self.inc_score = np.inf
Expand Down
30 changes: 22 additions & 8 deletions tests/test_dehb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
import time
import typing

import ConfigSpace
import numpy as np
import time
import pytest
from src.dehb.optimizers.dehb import DEHB


def create_toy_searchspace():
"""Creates a toy searchspace with a single hyperparameter.
Expand Down Expand Up @@ -66,8 +68,7 @@ class TestBudgetExhaustion():
evaluations and number of brackets to run.
"""
def test_runtime_exhaustion(self):
"""Test for runtime budget exhaustion.
"""
"""Test for runtime budget exhaustion."""
cs = create_toy_searchspace()
dehb = create_toy_optimizer(configspace=cs, min_budget=3, max_budget=27, eta=3,
objective_function=objective_function)
Expand All @@ -77,8 +78,7 @@ def test_runtime_exhaustion(self):
assert dehb._is_run_budget_exhausted(total_cost=1), "Run budget should be exhausted"

def test_fevals_exhaustion(self):
"""Test for function evaluations budget exhaustion.
"""
"""Test for function evaluations budget exhaustion."""
cs = create_toy_searchspace()
dehb = create_toy_optimizer(configspace=cs, min_budget=3, max_budget=27, eta=3,
objective_function=objective_function)
Expand All @@ -88,8 +88,7 @@ def test_fevals_exhaustion(self):
assert dehb._is_run_budget_exhausted(fevals=1), "Run budget should be exhausted"

def test_brackets_exhaustion(self):
"""Test for bracket budget exhaustion.
"""
"""Test for bracket budget exhaustion."""
cs = create_toy_searchspace()
dehb = create_toy_optimizer(configspace=cs, min_budget=3, max_budget=27, eta=3,
objective_function=objective_function)
Expand All @@ -98,3 +97,18 @@ def test_brackets_exhaustion(self):

assert dehb._is_run_budget_exhausted(brackets=1), "Run budget should be exhausted"

class TestInitialization:
"""Class that bundles all tests regarding the initialization of DEHB."""
def test_higher_min_budget(self):
"""Test that verifies, that DEHB breaks if min_budget > max_budget."""
cs = create_toy_searchspace()
with pytest.raises(AssertionError):
create_toy_optimizer(configspace=cs, min_budget=28, max_budget=27, eta=3,
objective_function=objective_function)

def test_equal_min_max_budget(self):
"""Test that verifies, that DEHB breaks if min_budget == max_budget."""
cs = create_toy_searchspace()
with pytest.raises(AssertionError):
create_toy_optimizer(configspace=cs, min_budget=27, max_budget=27, eta=3,
objective_function=objective_function)

0 comments on commit adac335

Please sign in to comment.