Skip to content

Commit

Permalink
Add pyarrow to dependencies, fix incumbent logging frequency and smal…
Browse files Browse the repository at this point in the history
…l init bug

Also adjust the unit tests accordingly.
  • Loading branch information
Bronzila committed May 29, 2024
1 parent 1ac9284 commit 28a18af
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ dependencies = [
"dask>=2.27.0",
"distributed>=2.27.0",
"ConfigSpace>=0.4.16",
"pandas>=1.4.4"
"pandas>=1.4.4",
"pyarrow>=16.1.0"
]
classifiers = [
"Programming Language :: Python :: 3.8",
Expand Down
65 changes: 44 additions & 21 deletions tests/test_dehb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
import json
import time
import builtins
import io
import os
import typing

import ConfigSpace
import numpy as np
import pandas as pd
import pytest
from src.dehb.optimizers.dehb import DEHB


def patch_open(open_func, files):
def open_patched(path, mode="r", buffering=-1, encoding=None,
errors=None, newline=None, closefd=True,
opener=None):
if "w" in mode and not os.path.isfile(path):
files.append(path)
return open_func(path, mode=mode, buffering=buffering,
encoding=encoding, errors=errors,
newline=newline, closefd=closefd,
opener=opener)
return open_patched


@pytest.fixture(autouse=True)
def cleanup_files(monkeypatch):
"""This fixture automatically cleans up all files that have been written by the tests after
execution.
"""
files = []
monkeypatch.setattr(builtins, "open", patch_open(builtins.open, files))
monkeypatch.setattr(io, "open", patch_open(io.open, files))
yield
for file in files:
os.remove(file)

def create_toy_searchspace():
"""Creates a toy searchspace with a single hyperparameter.
Expand Down Expand Up @@ -289,7 +316,7 @@ def test_state_before_eval(self):
de_params.pop("output_path")
for key in de_params:
assert de_params[key] == dehb.de_params[key]
def test_freq_incumbent(self):
def test_freq_step(self):
"""Verifies, that the save_freq 'step' saves the state at the right times."""
cs = create_toy_searchspace()
dehb = create_toy_optimizer(configspace=cs, min_fidelity=3, max_fidelity=27, eta=3,
Expand All @@ -300,25 +327,23 @@ def test_freq_incumbent(self):
result = objective_function(job_info["config"], job_info["fidelity"])
dehb.tell(job_info, result)

# Now state should be saved --> load config_repo
config_repo_path = dehb.output_path / "config_repository.json"
with config_repo_path.open() as f:
config_repo_list = json.load(f)
# Now state should be saved --> load history
history_path = dehb.output_path / "history.parquet.gzip"
history = pd.read_parquet(history_path)

assert len(config_repo_list) == len(dehb.config_repository.configs)
assert len(history) == len(dehb.history)

# Second ask/tell
job_info = dehb.ask()
# Result should be worse than first result so that it can not trigger "incumbent" save_freq
result["fitness"] += 10
dehb.tell(job_info, result)

# Now state should be saved --> load config_repo
config_repo_path = dehb.output_path / "config_repository.json"
with config_repo_path.open() as f:
config_repo_list = json.load(f)
# Now state should be saved --> load history
history_path = dehb.output_path / "history.parquet.gzip"
history = pd.read_parquet(history_path)

assert len(config_repo_list) == len(dehb.config_repository.configs)
assert len(history) == len(dehb.history)

def test_freq_incumbent(self):
"""Verifies, that the save_freq 'incumbent' saves the state at the right times."""
Expand All @@ -332,11 +357,10 @@ def test_freq_incumbent(self):
dehb.tell(job_info, result)

# Now state should be saved, because first config is always incumbent --> load config_repo
config_repo_path = dehb.output_path / "config_repository.json"
with config_repo_path.open() as f:
config_repo_list = json.load(f)
history_path = dehb.output_path / "history.parquet.gzip"
history = pd.read_parquet(history_path)

assert len(config_repo_list) == len(dehb.config_repository.configs)
assert len(history) == len(dehb.history)

# Second ask/tell
job_info = dehb.ask()
Expand All @@ -345,11 +369,10 @@ def test_freq_incumbent(self):
dehb.tell(job_info, result)

# State should not have been updated
config_repo_path = dehb.output_path / "config_repository.json"
with config_repo_path.open() as f:
config_repo_list = json.load(f)
history_path = dehb.output_path / "history.parquet.gzip"
history = pd.read_parquet(history_path)

assert len(config_repo_list) == len(dehb.config_repository.configs) - 1
assert len(history) == len(dehb.history) - 1

class TestRestart:
"""Class that bundles all tests regarding the restarting functionality of DEHB."""
Expand Down

0 comments on commit 28a18af

Please sign in to comment.