DeepGP checkpoint and refine
karibbov committed Sep 15, 2023
1 parent 11df8cb commit 8a10564
Showing 2 changed files with 113 additions and 3 deletions.
11 changes: 9 additions & 2 deletions neps_examples/efficiency/multi_fidelity_dyhpo.py
@@ -83,10 +83,17 @@ def run_pipeline(pipeline_directory, previous_pipeline_directory, learning_rate,
searcher="mf_ei_bo",
# Optional: Do not start another evaluation after <=100 epochs, corresponds to cost
# field above.
-max_cost_total=50,
+max_cost_total=60,
surrogate_model="deep_gp",
# Normalize y here since we return an unbounded loss; not entirely correct to do so
surrogate_model_args={
"surrogate_model_fit_args": {"normalize_y": True},
"surrogate_model_fit_args": {
"normalize_y": True,
"batch_size": 8,
"early_stopping": True,
},
"checkpointing": True,
"root_directory": "results/multi_fidelity_example",
},
step_size=3,
)
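
For context, a sketch of how these options sit in the surrounding neps.run call; illustrative only, with run_pipeline and pipeline_space assumed from earlier in the example file and the argument list abridged:

import neps

neps.run(
    run_pipeline=run_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/multi_fidelity_example",
    searcher="mf_ei_bo",
    max_cost_total=60,
    surrogate_model="deep_gp",
    surrogate_model_args={
        "surrogate_model_fit_args": {
            "normalize_y": True,
            "batch_size": 8,
            "early_stopping": True,
        },
        # Should point at the NePS run directory: the surrogate reads the
        # loss trajectory files from it and stores its checkpoint there.
        "checkpointing": True,
        "root_directory": "results/multi_fidelity_example",
    },
    step_size=3,
)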
105 changes: 104 additions & 1 deletion src/neps/optimizers/bayesian_optimization/models/deepGP.py
@@ -1,6 +1,9 @@
from __future__ import annotations

import logging
import os
from copy import deepcopy
from pathlib import Path

import gpytorch
import numpy as np
@@ -15,6 +18,32 @@
)


def count_non_improvement_steps(root_directory: Path | str) -> int:
root_directory = Path(root_directory)

all_losses_file = root_directory / "all_losses_and_configs.txt"
best_loss_file = root_directory / "best_loss_trajectory.txt"

# Read all losses from the file in the order they are explored
losses = [
float(line[6:])
for line in all_losses_file.read_text(encoding="utf-8").splitlines()
if "Loss: " in line
]
# Get the best seen loss value
best_loss = float(best_loss_file.read_text(encoding="utf-8").splitlines()[-1].strip())

# Count the non-improvement
count = 0
for loss in reversed(losses):
if np.greater(loss, best_loss):
count += 1
else:
break

return count
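
A self-contained usage sketch of the helper above (illustrative; it assumes the "Loss: <value>" line format the parser expects):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Losses in exploration order; the best value (0.20) is followed by
    # two non-improving evaluations.
    (root / "all_losses_and_configs.txt").write_text(
        "Loss: 0.50\nLoss: 0.30\nLoss: 0.20\nLoss: 0.40\nLoss: 0.25\n",
        encoding="utf-8",
    )
    # Only the last line of the trajectory file is read.
    (root / "best_loss_trajectory.txt").write_text(
        "0.50\n0.30\n0.20\n", encoding="utf-8"
    )
    print(count_non_improvement_steps(root))  # -> 2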


class NeuralFeatureExtractor(nn.Module):
"""
Neural network to be used in the DeepGP
@@ -134,11 +163,27 @@ def __init__(
neural_network_args: dict | None = None,
logger=None,
surrogate_model_fit_args: dict | None = None,
# IMPORTANT: Checkpointing does not use file locking,
# IMPORTANT: hence, it is not suitable for multiprocessing settings
checkpointing: bool = False,
root_directory: Path | str | None = None,
checkpoint_file: Path | str = "surrogate_checkpoint.pth",
refine_epochs: int = 50,
**kwargs, # pylint: disable=unused-argument
):
self.surrogate_model_fit_args = (
surrogate_model_fit_args if surrogate_model_fit_args is not None else {}
)

self.checkpointing = checkpointing
self.refine_epochs = refine_epochs
if checkpointing:
assert (
root_directory is not None
), "neps root_directory must be provided for the checkpointing"
self.root_dir = Path(os.getcwd(), root_directory)
self.checkpoint_path = Path(os.getcwd(), root_directory, checkpoint_file)

super().__init__()
self.__preprocess_search_space(pipeline_space)
# set the categories array for the encoder
@@ -150,7 +195,7 @@ def __init__(
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.device = torch.device("cpu")
# self.device = torch.device("cpu")

# Save the NN args, necessary for preprocessing
self.cnn_kernel_size = neural_network_args.get("cnn_kernel_size", 3)
@@ -339,6 +384,7 @@ def _fit(
optimizer_args: dict | None = None,
early_stopping: bool = True,
patience: int = 10,
perf_patience: int = 10,
):
self.__reset_xy(
x_train,
@@ -348,6 +394,15 @@
normalize_budget=normalize_budget,
)

if self.checkpointing:
non_improvement_steps = count_non_improvement_steps(self.root_dir)
# If checkpointing is enabled and patience is not exhausted, load the partial model
if self.checkpoint_path.exists() and non_improvement_steps < perf_patience:
n_epochs = self.refine_epochs
self.load_checkpoint(self.checkpoint_path)
self.logger.info(f"No improvement: {non_improvement_steps}")
self.logger.info(f"N Epochs: {n_epochs}")

self.model.to(self.device)
self.likelihood.to(self.device)
self.nn.to(self.device)
@@ -363,6 +418,7 @@
early_stopping=early_stopping,
patience=patience,
)
self.save_checkpoint(self.checkpoint_path)
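
In effect, the first fit trains from scratch for the full budget, while later fits warm-start from the checkpoint and run only refine_epochs, as long as the run improved within the last perf_patience steps. A condensed, illustrative restatement of that gating (n_epochs is assumed to be the full training budget bound in the elided part of the _fit signature):

if (
    self.checkpointing
    and self.checkpoint_path.exists()
    and count_non_improvement_steps(self.root_dir) < perf_patience
):
    self.load_checkpoint(self.checkpoint_path)  # warm start from the last state
    n_epochs = self.refine_epochs  # short refinement instead of full training
# otherwise: train from scratch for the full n_epochs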

def __train_model(
self,
@@ -522,6 +578,53 @@ def predict(

return means, cov

def load_checkpoint(self, checkpoint_path: str | Path):
"""
Load the state from a previous checkpoint.
"""
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint["gp_state_dict"])
self.nn.load_state_dict(checkpoint["nn_state_dict"])
self.likelihood.load_state_dict(checkpoint["likelihood_state_dict"])

def save_checkpoint(self, checkpoint_path: str | Path, state: dict | None = None):
"""
Save the given state or the current state in a
checkpoint file.
Args:
checkpoint_path: path to the checkpoint file
state: The state to save; if None, the current state is saved.
"""

if state is None:
torch.save(
self.get_state(),
checkpoint_path,
)
else:
torch.save(
state,
checkpoint_path,
)

def get_state(self) -> dict[str, dict]:
"""
Get the current state of the surrogate.
Returns:
current_state: A dictionary that represents
the current state of the surrogate model.
"""
current_state = {
"gp_state_dict": deepcopy(self.model.state_dict()),
"nn_state_dict": deepcopy(self.nn.state_dict()),
"likelihood_state_dict": deepcopy(self.likelihood.state_dict()),
}

return current_state
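
Putting the three methods together, a minimal round-trip sketch, assuming surrogate is an already-fitted DeepGP instance:

ckpt = Path("results/multi_fidelity_example/surrogate_checkpoint.pth")

state = surrogate.get_state()           # deep-copied GP, NN, and likelihood state dicts
surrogate.save_checkpoint(ckpt, state)  # equivalent to surrogate.save_checkpoint(ckpt)
surrogate.load_checkpoint(ckpt)         # restores all three modules in place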


if __name__ == "__main__":
print(torch.version.__version__)
