Skip to content

Commit

Permalink
Fix empty dict parquet issues and typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Bronzila committed May 29, 2024
1 parent 66d71c3 commit 1ac9284
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/dehb/optimizers/dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from pathlib import Path
from threading import Timer
from typing import Union
from typing import List, Tuple, Union

import ConfigSpace
import numpy as np
Expand Down Expand Up @@ -144,7 +144,7 @@ def reset(self, *, reset_seeds: bool = True):
def _init_population(self):
raise NotImplementedError("Redefine!")

def _get_next_iteration(self, iteration: int) -> tuple[np.array, np.array]:
def _get_next_iteration(self, iteration: int) -> Tuple[np.array, np.array]:
"""Computes the Successive Halving spacing.
Given the iteration index, computes the fidelity spacing to be used and
Expand All @@ -171,7 +171,7 @@ def _get_next_iteration(self, iteration: int) -> tuple[np.array, np.array]:

return ns, fidelities

def get_incumbents(self) -> tuple[Union[dict, ConfigSpace.Configuration], float]:
def get_incumbents(self) -> Tuple[Union[dict, ConfigSpace.Configuration], float]:
"""Retrieve current incumbent configuration and score.
Returns:
Expand Down Expand Up @@ -306,7 +306,7 @@ def _f_objective(self, job_info):
run_info.update({"device_id": device_id})
return run_info

def _create_cuda_visible_devices(self, available_gpus: list[int], start_id: int) -> str:
def _create_cuda_visible_devices(self, available_gpus: List[int], start_id: int) -> str:
"""Generates a string to set the CUDA_VISIBLE_DEVICES environment variable.
Given a list of available GPU device IDs and a preferred ID (start_id), the environment
Expand Down Expand Up @@ -675,7 +675,7 @@ def _get_next_job(self):
break
return job_info

def ask(self, n_configs: int=1) -> Union[dict, list[dict]]:
def ask(self, n_configs: int=1) -> Union[dict, List[dict]]:
"""Get the next configuration to run from the optimizer.
The retrieved configuration can then be evaluated by the user.
Expand Down Expand Up @@ -846,6 +846,10 @@ def _save_history(self, name="history.parquet.gzip"):
history_path = self.output_path / name
history_df = pd.DataFrame(self.history, columns=["config_id", "config", "fitness",
"cost", "fidelity", "info"])
# Check whether every entry in the 'info' column is an empty dict
# (an all-empty-dict column causes issues when writing the parquet file)
if history_df["info"].apply(lambda x: (isinstance(x, dict) and len(x) == 0)).all():
# Drop the 'info' column
history_df = history_df.drop(columns=["info"])
history_df.to_parquet(history_path, compression="gzip")
except Exception as e:
self.logger.warning(f"History not saved: {e!r}")
Expand Down Expand Up @@ -925,7 +929,7 @@ def _load_checkpoint(self, run_dir: str):
result = {
"fitness": row["fitness"],
"cost": row["cost"],
"info": row["info"],
"info": row.get("info", {}),
}

self.tell(job_info, result, replay=True)
Expand Down Expand Up @@ -984,7 +988,7 @@ def tell(self, job_info: dict, result: dict, replay: bool=False) -> None:
self._tell_counter += 1
# Update bracket information
fitness, cost = float(result["fitness"]), float(result["cost"])
info = result["info"] if "info" in result else dict()
info = result["info"] if "info" in result else {}
fidelity, parent_id = job_info["fidelity"], job_info["parent_id"]
config, config_id = job_info["config"], job_info["config_id"]
bracket_id = job_info["bracket_id"]
Expand Down Expand Up @@ -1025,7 +1029,7 @@ def tell(self, job_info: dict, result: dict, replay: bool=False) -> None:

@logger.catch
def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus=False,
verbose=False, debug=False, **kwargs) -> tuple[np.array, np.array, np.array]:
verbose=False, debug=False, **kwargs) -> Tuple[np.array, np.array, np.array]:
"""Main interface to run optimization by DEHB.
This function waits on workers and if a worker is free, asks for a configuration and a
Expand Down

0 comments on commit 1ac9284

Please sign in to comment.