Skip to content

Commit

Permalink
Merge pull request #78 from automl/feat/tell_errors
Browse files Browse the repository at this point in the history
Add errors for #tells > #asks
  • Loading branch information
Bronzila authored Apr 1, 2024
2 parents 118e896 + e65d71c commit 9408380
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/dehb/optimizers/dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=0.5,
self.traj = []
self.runtime = []
self.history = []
self._ask_counter = 0
self._tell_counter = 0
self.start = None
if save_freq not in ["incumbent", "step", "end"] and save_freq is not None:
self.logger.warning(f"Save frequency {save_freq} unknown. Resorting to using 'end'.")
Expand Down Expand Up @@ -355,6 +357,9 @@ def reset(self):
self.traj = []
self.runtime = []
self.history = []
self._ask_counter = 0
self._tell_counter = 0
self.config_repository.reset()
self._get_pop_sizes()
self._init_subpop()
self.available_gpus = None
Expand Down Expand Up @@ -651,9 +656,11 @@ def ask(self, n_configs: int=1):
jobs = []
if n_configs == 1:
jobs = self._get_next_job()
self._ask_counter += 1
else:
for _ in range(n_configs):
jobs.append(self._get_next_job())
self._ask_counter += 1
# Save random state after ask
self.random_state = self.rng.bit_generator.state
if self.use_configspace:
Expand Down Expand Up @@ -980,6 +987,10 @@ def tell(self, job_info: dict, result: dict, replay: bool=False):
# Replace job_info with container to make sure all fields are given
job_info = job_info_container

if self._tell_counter >= self._ask_counter:
raise NotImplementedError("Called tell() more often than ask(). \
Warmstarting with tell is not supported. ")
self._tell_counter += 1
# Update bracket information
fitness, cost = result["fitness"], result["cost"]
info = result["info"] if "info" in result else dict()
Expand Down
15 changes: 15 additions & 0 deletions tests/test_dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,21 @@ def test_ask_twice_different(self):
job_info_b = dehb.ask()
assert job_info_a != job_info_b

def test_tell_twice(self):
"""Verifies, that tell should not be allowed to be called more often than ask."""
cs = create_toy_searchspace()
dehb = create_toy_optimizer(configspace=cs, min_fidelity=3, max_fidelity=27, eta=3,
objective_function=objective_function)
# Get single job info
job_info = dehb.ask()
res = objective_function(job_info["config"], job_info["fidelity"])

# Tell twice, first should work
dehb.tell(job_info, res)
# Second tell should raise an error
with pytest.raises(NotImplementedError):
dehb.tell(job_info, res)

def test_tell_successful(self):
"""Verifies, that tell successfully saves results."""
cs = create_toy_searchspace()
Expand Down

0 comments on commit 9408380

Please sign in to comment.