From 28a18af288e338e69c478f4024d79da031e88431 Mon Sep 17 00:00:00 2001 From: Janis Fix Date: Wed, 29 May 2024 11:11:27 +0200 Subject: [PATCH] Add pyarrow to dependencies, fix incumbent logging frequency and small init bug Also adjust the unit tests accordingly. --- pyproject.toml | 3 ++- tests/test_dehb.py | 65 +++++++++++++++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d538720..c24b39d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "dask>=2.27.0", "distributed>=2.27.0", "ConfigSpace>=0.4.16", - "pandas>=1.4.4" + "pandas>=1.4.4", + "pyarrow>=16.1.0" ] classifiers = [ "Programming Language :: Python :: 3.8", diff --git a/tests/test_dehb.py b/tests/test_dehb.py index 707a1bf..2346948 100644 --- a/tests/test_dehb.py +++ b/tests/test_dehb.py @@ -1,13 +1,40 @@ -import json -import time +import builtins +import io +import os import typing import ConfigSpace import numpy as np +import pandas as pd import pytest from src.dehb.optimizers.dehb import DEHB +def patch_open(open_func, files): + def open_patched(path, mode="r", buffering=-1, encoding=None, + errors=None, newline=None, closefd=True, + opener=None): + if "w" in mode and not os.path.isfile(path): + files.append(path) + return open_func(path, mode=mode, buffering=buffering, + encoding=encoding, errors=errors, + newline=newline, closefd=closefd, + opener=opener) + return open_patched + + +@pytest.fixture(autouse=True) +def cleanup_files(monkeypatch): + """This fixture automatically cleans up all files that have been written by the tests after + execution. + """ + files = [] + monkeypatch.setattr(builtins, "open", patch_open(builtins.open, files)) + monkeypatch.setattr(io, "open", patch_open(io.open, files)) + yield + for file in files: + os.remove(file) + def create_toy_searchspace(): """Creates a toy searchspace with a single hyperparameter. @@ -289,7 +316,7 @@ def test_state_before_eval(self): de_params.pop("output_path") for key in de_params: assert de_params[key] == dehb.de_params[key] - def test_freq_incumbent(self): + def test_freq_step(self): """Verifies, that the save_freq 'step' saves the state at the right times.""" cs = create_toy_searchspace() dehb = create_toy_optimizer(configspace=cs, min_fidelity=3, max_fidelity=27, eta=3, @@ -300,12 +327,11 @@ def test_freq_incumbent(self): result = objective_function(job_info["config"], job_info["fidelity"]) dehb.tell(job_info, result) - # Now state should be saved --> load config_repo - config_repo_path = dehb.output_path / "config_repository.json" - with config_repo_path.open() as f: - config_repo_list = json.load(f) + # Now state should be saved --> load history + history_path = dehb.output_path / "history.parquet.gzip" + history = pd.read_parquet(history_path) - assert len(config_repo_list) == len(dehb.config_repository.configs) + assert len(history) == len(dehb.history) # Second ask/tell job_info = dehb.ask() @@ -313,12 +339,11 @@ def test_freq_incumbent(self): result["fitness"] += 10 dehb.tell(job_info, result) - # Now state should be saved --> load config_repo - config_repo_path = dehb.output_path / "config_repository.json" - with config_repo_path.open() as f: - config_repo_list = json.load(f) + # Now state should be saved --> load history + history_path = dehb.output_path / "history.parquet.gzip" + history = pd.read_parquet(history_path) - assert len(config_repo_list) == len(dehb.config_repository.configs) + assert len(history) == len(dehb.history) def test_freq_incumbent(self): """Verifies, that the save_freq 'incumbent' saves the state at the right times.""" @@ -332,11 +357,10 @@ def test_freq_incumbent(self): dehb.tell(job_info, result) # Now state should be saved, because first config is always incumbent --> load config_repo - config_repo_path = dehb.output_path / "config_repository.json" - with config_repo_path.open() as f: - config_repo_list = json.load(f) + history_path = dehb.output_path / "history.parquet.gzip" + history = pd.read_parquet(history_path) - assert len(config_repo_list) == len(dehb.config_repository.configs) + assert len(history) == len(dehb.history) # Second ask/tell job_info = dehb.ask() @@ -345,11 +369,10 @@ def test_freq_incumbent(self): dehb.tell(job_info, result) # State should not have been updated - config_repo_path = dehb.output_path / "config_repository.json" - with config_repo_path.open() as f: - config_repo_list = json.load(f) + history_path = dehb.output_path / "history.parquet.gzip" + history = pd.read_parquet(history_path) - assert len(config_repo_list) == len(dehb.config_repository.configs) - 1 + assert len(history) == len(dehb.history) - 1 class TestRestart: """Class that bundles all tests regarding the restarting functionality of DEHB."""