Skip to content

Commit

Permalink
refactor: observed_configs -> max_budget_configs
Browse files Browse the repository at this point in the history
  • Loading branch information
karibbov committed Apr 10, 2024
1 parent eb46e35 commit e824c3f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 32 deletions.
12 changes: 6 additions & 6 deletions neps/optimizers/multi_fidelity/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _update_sh_bracket_state(self) -> None:
# for the current SH bracket in HB
# TODO: can we avoid copying full observation history
bracket = self.sh_brackets[self.current_sh_bracket] # type: ignore
bracket.observed_configs = self.observed_configs.copy()
# bracket.max_budget_configs = self.max_budget_configs.copy()
# TODO: Do we NEED to copy here instead?
bracket.MFobserved_configs = self.MFobserved_configs

Expand Down Expand Up @@ -170,22 +170,22 @@ def clear_old_brackets(self):
base_rung_sizes = [] # sorted(self.config_map.values(), reverse=True)
for bracket in self.sh_brackets.values():
base_rung_sizes.append(sorted(bracket.config_map.values(), reverse=True)[0])
while end <= len(self.observed_configs):
while end <= len(self.max_budget_configs):
# subsetting only this SH bracket from the history
sh_bracket = self.sh_brackets[self.current_sh_bracket]
sh_bracket.clean_rung_information()
# for the SH bracket in start-end, calculate total SH budget used, from the
# correct SH bracket object to make the right budget calculations
# pylint: disable=protected-access
bracket_budget_used = sh_bracket._calc_budget_used_in_bracket(
deepcopy(self.observed_configs.rung.values[start:end])
deepcopy(self.max_budget_configs.rung.values[start:end])
)
# if budget used is less than the total SH budget then still an active bracket
current_bracket_full_budget = sum(sh_bracket.full_rung_trace)
if bracket_budget_used < current_bracket_full_budget:
# updating rung information of the current bracket
# pylint: disable=protected-access
sh_bracket._get_rungs_state(self.observed_configs.iloc[start:end])
sh_bracket._get_rungs_state(self.max_budget_configs.iloc[start:end])
# extra call to use the updated rung member info to find promotions
# SyncPromotion signals a wait if a rung is full but with
# incomplete/pending evaluations, and signals to start a new SH bracket
Expand All @@ -210,7 +210,7 @@ def clear_old_brackets(self):
# updates rung info with the latest active, incomplete bracket
sh_bracket = self.sh_brackets[self.current_sh_bracket]
# pylint: disable=protected-access
sh_bracket._get_rungs_state(self.observed_configs.iloc[start:end])
sh_bracket._get_rungs_state(self.max_budget_configs.iloc[start:end])
sh_bracket._handle_promotions()
# self._handle_promotion() need not be called as it is called by load_results()

Expand Down Expand Up @@ -380,7 +380,7 @@ def _update_sh_bracket_state(self) -> None:
config_map=bracket.config_map,
)
bracket.rung_promotions = bracket.promotion_policy.retrieve_promotions()
bracket.observed_configs = self.observed_configs.copy()
bracket.max_budget_configs = self.max_budget_configs.copy()

def _get_bracket_to_run(self):
"""Samples the ASHA bracket to run.
Expand Down
45 changes: 31 additions & 14 deletions neps/optimizers/multi_fidelity/successive_halving.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Literal

import numpy as np
import pandas as pd

from ...metahyper import ConfigResult
from ...search_spaces.hyperparameters.categorical import (
Expand Down Expand Up @@ -143,7 +142,9 @@ def __init__(
# TODO: replace with MFobserved_configs
# stores the observations made and the corresponding fidelity explored
# crucial data structure used for determining promotion candidates
self.observed_configs = pd.DataFrame([], columns=("config", "rung", "perf"))
self.__max_observed_configs = None
self.history_length = 0
# self.max_budget_configs = pd.DataFrame([], columns=("config", "rung", "perf"))
# stores which configs occupy each rung at any time
self.rung_members: dict = dict() # stores config IDs per rung
self.rung_members_performance: dict = dict() # performances recorded per rung
Expand All @@ -164,6 +165,22 @@ def __init__(
self._enhance_priors()
self.rung_histories = None

@property
def max_budget_configs(self):
    """Return the per-config observations at their highest seen budget.

    The table is derived from ``self.MFobserved_configs`` so that the state
    of the algorithm only depends on ``self.MFobserved_configs``. The derived
    table is cached and rebuilt only when no cache exists yet or when the
    number of rows in ``self.MFobserved_configs.df`` differs from the length
    recorded at the last rebuild (``self.history_length``).

    NOTE(review): the cache is keyed on the history length only; in-place
    updates that keep the row count unchanged will not trigger a rebuild --
    confirm that this matches how ``MFobserved_configs`` is updated.

    Returns:
        The cached result of ``MFobserved_configs.copy_df`` applied to
        ``MFobserved_configs.reduce_to_max_seen_budgets()``.
    """
    current_length = len(self.MFobserved_configs.df)
    cache_is_stale = (
        self.__max_observed_configs is None
        or self.history_length != current_length
    )
    if cache_is_stale:
        self.__max_observed_configs = self.MFobserved_configs.copy_df(
            df=self.MFobserved_configs.reduce_to_max_seen_budgets()
        )
        # re-read the length after the rebuild, as the original code does
        self.history_length = len(self.MFobserved_configs.df)

    return self.__max_observed_configs

@max_budget_configs.setter
def max_budget_configs(self, value):
    """Allow explicit assignment, as callers did with the old attribute.

    Without a setter, statements such as
    ``bracket.max_budget_configs = self.max_budget_configs.copy()`` raise
    ``AttributeError`` because the property would be read-only.

    The assigned table is kept until the length of
    ``self.MFobserved_configs.df`` changes, at which point the getter
    rebuilds the table from ``self.MFobserved_configs`` again.
    """
    self.__max_observed_configs = value
    # remember the history length the assigned table corresponds to, so the
    # getter does not immediately discard it
    self.history_length = len(self.MFobserved_configs.df)

@classmethod
def _get_rung_trace(cls, rung_map: dict, config_map: dict) -> list[int]:
"""Lists the rung IDs in sequence of the flattened SH tree."""
Expand All @@ -178,8 +195,8 @@ def get_incumbent_score(self):

# TODO: replace this with existing method
y_star = np.inf # minimizing optimizer
if len(self.observed_configs):
y_star = self.observed_configs.perf.values.min()
if len(self.max_budget_configs):
y_star = self.max_budget_configs.perf.values.min()
return y_star

def _get_rung_map(self, s: int = 0) -> dict:
Expand Down Expand Up @@ -325,7 +342,7 @@ def _get_rungs_state(self, observed_configs=None):
"""Collects info on configs at a rung and their performance there."""
# to account for incomplete evaluations from being promoted --> working on a copy
observed_configs = (
self.observed_configs.copy().dropna(inplace=False)
self.max_budget_configs.copy().dropna(inplace=False)
if observed_configs is None
else observed_configs
)
Expand Down Expand Up @@ -400,9 +417,9 @@ def load_results(

# TODO: change this after testing
# Copy data into old format
self.observed_configs = self.MFobserved_configs.copy_df(
df=self.MFobserved_configs.reduce_to_max_seen_budgets()
)
# self.max_budget_configs = self.MFobserved_configs.copy_df(
# df=self.MFobserved_configs.reduce_to_max_seen_budgets()
# )

# process optimization state and bucket observations per rung
self._get_rungs_state()
Expand Down Expand Up @@ -470,7 +487,7 @@ def get_config_and_ids( # pylint: disable=no-self-use
if rung_to_promote is not None:
# promotes the first recorded promotable config in the argsort-ed rung
# TODO: What to do with this?
row = self.observed_configs.iloc[self.rung_promotions[rung_to_promote][0]]
row = self.max_budget_configs.iloc[self.rung_promotions[rung_to_promote][0]]
config = deepcopy(row["config"])
rung = rung_to_promote + 1
# assigning the fidelity to evaluate the config at
Expand All @@ -484,7 +501,7 @@ def get_config_and_ids( # pylint: disable=no-self-use
if (
self.use_priors
and self.sample_default_first
and len(self.observed_configs) == 0
and len(self.max_budget_configs) == 0
):
if self.sample_default_at_target:
# sets the default config to be evaluated at the target fidelity
Expand Down Expand Up @@ -568,15 +585,15 @@ def clear_old_brackets(self):
start += 1
end += 1
# iterates over the different SH brackets which span start-end by index
while end <= len(self.observed_configs):
while end <= len(self.max_budget_configs):
# for the SH bracket in start-end, calculate total SH budget used
bracket_budget_used = self._calc_budget_used_in_bracket(
deepcopy(self.observed_configs.rung.values[start:end])
deepcopy(self.max_budget_configs.rung.values[start:end])
)
# if budget used is less than a SH bracket budget then still an active bracket
if bracket_budget_used < sum(self.full_rung_trace):
# subsetting only this SH bracket from the history
self._get_rungs_state(self.observed_configs.iloc[start:end])
self._get_rungs_state(self.max_budget_configs.iloc[start:end])
# extra call to use the updated rung member info to find promotions
# SyncPromotion signals a wait if a rung is full but with
# incomplete/pending evaluations, and signals to start a new SH bracket
Expand All @@ -594,7 +611,7 @@ def clear_old_brackets(self):
end = start + self.config_map[self.min_rung]

# updates rung info with the latest active, incomplete bracket
self._get_rungs_state(self.observed_configs.iloc[start:end])
self._get_rungs_state(self.max_budget_configs.iloc[start:end])
# _handle_promotion() need not be called as it is called by load_results()
return

Expand Down
4 changes: 2 additions & 2 deletions neps/optimizers/multi_fidelity_prior/async_priorband.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import typing
from typing import Literal

import numpy as np
from typing import Literal

from ...metahyper import ConfigResult
from ...search_spaces.search_space import SearchSpace
Expand Down Expand Up @@ -238,7 +238,7 @@ def _update_sh_bracket_state(self) -> None:
config_map=bracket.config_map,
)
bracket.rung_promotions = bracket.promotion_policy.retrieve_promotions()
bracket.observed_configs = self.observed_configs.copy()
bracket.max_budget_configs = self.max_budget_configs.copy()
bracket.rung_histories = self.rung_histories

def load_results(
Expand Down
22 changes: 12 additions & 10 deletions neps/optimizers/multi_fidelity_prior/priorband.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

import typing
from typing import Literal

import numpy as np
from typing import Literal

from ...search_spaces.search_space import SearchSpace
from ..bayesian_optimization.acquisition_functions.base_acquisition import BaseAcquisition
Expand Down Expand Up @@ -34,7 +34,7 @@ def find_all_distances_from_incumbent(self, incumbent):
"""Finds the distance to the nearest neighbour."""
dist = lambda x: compute_config_dist(incumbent, x)
# computing distance of incumbent from all seen points in history
distances = [dist(config) for config in self.observed_configs.config]
distances = [dist(config) for config in self.max_budget_configs.config]
# ensuring the distances exclude 0 or the distance from itself
distances = [d for d in distances if d > 0]
return distances
Expand All @@ -47,14 +47,14 @@ def find_1nn_distance_from_incumbent(self, incumbent):

def find_incumbent(self, rung: int = None) -> SearchSpace:
"""Find the best performing configuration seen so far."""
rungs = self.observed_configs.rung.values
idxs = self.observed_configs.index.values
rungs = self.max_budget_configs.rung.values
idxs = self.max_budget_configs.index.values
while rung is not None:
# enters this scope if the `rung` argument is passed and not left empty or None
if rung not in rungs:
self.logger.warn(f"{rung} not in {np.unique(idxs)}")
# filtering by rung based on argument passed
idxs = self.observed_configs.rung.values == rung
idxs = self.max_budget_configs.rung.values == rung
# checking width of current rung
if len(idxs) < self.eta:
self.logger.warn(
Expand All @@ -63,9 +63,9 @@ def find_incumbent(self, rung: int = None) -> SearchSpace:
# extracting the incumbent configuration
if len(idxs):
# finding the config with the lowest recorded performance
_perfs = self.observed_configs.loc[idxs].perf.values
_perfs = self.max_budget_configs.loc[idxs].perf.values
inc_idx = np.nanargmin([np.nan if t is None else t for t in _perfs])
inc = self.observed_configs.loc[idxs].iloc[inc_idx].config
inc = self.max_budget_configs.loc[idxs].iloc[inc_idx].config
else:
# THIS block should not ever execute, but for runtime anomalies, if no
# incumbent can be extracted, the prior is treated as the incumbent
Expand Down Expand Up @@ -126,7 +126,9 @@ def is_activate_inc(self) -> bool:
resources += bracket.config_map[rung] * continuation_resources

# find resources spent so far for all finished evaluations
resources_used = calc_total_resources_spent(self.observed_configs, self.rung_map)
resources_used = calc_total_resources_spent(
self.max_budget_configs, self.rung_map
)

if resources_used >= resources and len(
self.rung_histories[self.max_rung]["config"]
Expand Down Expand Up @@ -190,7 +192,7 @@ def prior_to_incumbent_ratio(self) -> float | float:
if self.inc_style == "constant":
return self._prior_to_incumbent_ratio_constant()
elif self.inc_style == "decay":
resources = calc_total_resources_spent(self.observed_configs, self.rung_map)
resources = calc_total_resources_spent(self.max_budget_configs, self.rung_map)
return self._prior_to_incumbent_ratio_decay(
resources, self.eta, self.min_budget, self.max_budget
)
Expand Down Expand Up @@ -244,7 +246,7 @@ def _prior_to_incumbent_ratio_dynamic(self, rung: int) -> float | float:
[
# `compute_scores` returns a tuple of scores resp. by prior and inc
compute_scores(
self.observed_configs.loc[config_id].config, prior, inc
self.max_budget_configs.loc[config_id].config, prior, inc
)
for config_id in top_configs
]
Expand Down

0 comments on commit e824c3f

Please sign in to comment.