Skip to content

Commit

Permalink
Added return values
Browse files Browse the repository at this point in the history
  • Loading branch information
karibbov committed Nov 6, 2023
1 parent c5d29ab commit fe782b6
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions src/neps/optimizers/multi_fidelity/mf_bo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

from copy import deepcopy

import numpy as np
import pandas as pd
import torch
Expand All @@ -15,8 +16,8 @@


class MFBOBase:
""" Designed to work with model-based search on SH-based multi-fidelity algorithms.
"""Designed to work with model-based search on SH-based multi-fidelity algorithms.
Requires certain strict assumptions about fidelities and rung maps.
"""

Expand Down Expand Up @@ -182,7 +183,7 @@ def sample_new_config(


class FreezeThawModel:
""" Designed to work with model search in unit step multi-fidelity algorithms."""
"""Designed to work with model search in unit step multi-fidelity algorithms."""

def __init__(
self,
Expand Down Expand Up @@ -240,29 +241,34 @@ def _fit(self, train_x, train_y, train_lcs):
raise ValueError(
f"Surrogate model {self.surrogate_model_name} not supported!"
)

def _predict(self, test_x, test_lcs):
if self.surrogate_model_name in ["gp", "gp_hierarchy"]:
self.surrogate_model.predict(test_x)
return self.surrogate_model.predict(test_x)
elif self.surrogate_model_name in ["deep_gp", "pfn"]:
self.surrogate_model.predict(test_x, test_lcs)
return self.surrogate_model.predict(test_x, test_lcs)
else:
# check neps/optimizers/bayesian_optimization/models/__init__.py for options
raise ValueError(
f"Surrogate model {self.surrogate_model_name} not supported!"
)

def set_state(self, pipeline_space, surrogate_model_args, **kwargs):
def set_state(
self,
pipeline_space,
surrogate_model_args,
**kwargs, # pylint: disable=unused-argument
):
self.pipeline_space = pipeline_space
self.surrogate_model_args = (
surrogate_model_args if surrogate_model_args is not None else {}
)
# only to handle tabular spaces
if self.pipeline_space.has_tabular:
if self.surrogate_model_name in ["deep_gp", "pfn"]:
self.surrogate_model_args.update({
"pipeline_space": self.pipeline_space.raw_tabular_space
})
self.surrogate_model_args.update(
{"pipeline_space": self.pipeline_space.raw_tabular_space}
)
# instantiate the surrogate model, again, with the new pipeline space
self.surrogate_model = instance_from_map(
SurrogateModelMapping,
Expand All @@ -289,12 +295,13 @@ def update_model(self, train_x=None, train_y=None, pending_x=None, decay_t=None)

class PFNSurrogate(FreezeThawModel):
"""Special class to deal with PFN surrogate model and freeze-thaw acquisition."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.train_x = None
self.train_y = None

def _fit(self, *args):
def _fit(self, *args): # pylint: disable=unused-argument
assert self.surrogate_model_name == "pfn"
self.preprocess_training_set()
self.surrogate_model.fit(self.train_x, self.train_y)
Expand All @@ -310,7 +317,7 @@ def preprocess_training_set(self):
_configs = map_real_hyperparameters_from_tabular_ids(
pd.Series(_configs, index=_idxs), self.pipeline_space
).values

device = self.surrogate_model.device
# TODO: fix or make consistent with `tokenize``
configs, idxs, performances = self.observed_configs.get_tokenized_data(
Expand All @@ -326,9 +333,9 @@ def preprocess_test_set(self, test_x):

new_idxs = np.arange(_len, len(test_x))
base_fidelity = np.array([1] * len(new_idxs))
new_token_ids = np.hstack((
new_idxs.T.reshape(-1, 1), base_fidelity.T.reshape(-1, 1)
))
new_token_ids = np.hstack(
(new_idxs.T.reshape(-1, 1), base_fidelity.T.reshape(-1, 1))
)
# the following operation takes each element in the array and stacks it vertically
# in this case, should convert a (n,) array to (n, 2) by flattening the elements
existing_token_ids = np.vstack(self.observed_configs.token_ids).astype(int)
Expand Down

0 comments on commit fe782b6

Please sign in to comment.