From 8a1056417f00501a5f059805edb5f5f71eeabf48 Mon Sep 17 00:00:00 2001
From: karibbov
Date: Fri, 15 Sep 2023 17:03:53 +0200
Subject: [PATCH] DeepGP checkpoint and refine

---
 .../efficiency/multi_fidelity_dyhpo.py        |  11 +-
 .../bayesian_optimization/models/deepGP.py    | 105 +++++++++++++++++-
 2 files changed, 113 insertions(+), 3 deletions(-)

diff --git a/neps_examples/efficiency/multi_fidelity_dyhpo.py b/neps_examples/efficiency/multi_fidelity_dyhpo.py
index a355abc2..3aa463e1 100644
--- a/neps_examples/efficiency/multi_fidelity_dyhpo.py
+++ b/neps_examples/efficiency/multi_fidelity_dyhpo.py
@@ -83,10 +83,17 @@ def run_pipeline(pipeline_directory, previous_pipeline_directory, learning_rate
     searcher="mf_ei_bo",
     # Optional: Do not start another evaluation after <=100 epochs, corresponds to cost
     # field above.
-    max_cost_total=50,
+    max_cost_total=60,
     surrogate_model="deep_gp",
     # Normalizing y here since we return unbounded loss, not completely correct to do so
     surrogate_model_args={
-        "surrogate_model_fit_args": {"normalize_y": True},
+        "surrogate_model_fit_args": {
+            "normalize_y": True,
+            "batch_size": 8,
+            "early_stopping": True,
+        },
+        "checkpointing": True,
+        "root_directory": "results/multi_fidelity_example",
     },
+    step_size=3,
 )
diff --git a/src/neps/optimizers/bayesian_optimization/models/deepGP.py b/src/neps/optimizers/bayesian_optimization/models/deepGP.py
index ca1ebd6b..49fbe74c 100644
--- a/src/neps/optimizers/bayesian_optimization/models/deepGP.py
+++ b/src/neps/optimizers/bayesian_optimization/models/deepGP.py
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
 import logging
+import os
+from copy import deepcopy
+from pathlib import Path
 
 import gpytorch
 import numpy as np
@@ -15,6 +18,32 @@
 )
 
 
+def count_non_improvement_steps(root_directory: Path | str) -> int:
+    root_directory = Path(root_directory)
+
+    all_losses_file = root_directory / "all_losses_and_configs.txt"
+    best_loss_file = root_directory / "best_loss_trajectory.txt"
+
+    # Read all losses from the file in the order they were explored
+    losses = [
+        float(line[6:])
+        for line in all_losses_file.read_text(encoding="utf-8").splitlines()
+        if "Loss: " in line
+    ]
+    # Get the best loss value seen so far
+    best_loss = float(best_loss_file.read_text(encoding="utf-8").splitlines()[-1].strip())
+
+    # Count the most recent consecutive non-improving evaluations
+    count = 0
+    for loss in reversed(losses):
+        if np.greater(loss, best_loss):
+            count += 1
+        else:
+            break
+
+    return count
+
+
 class NeuralFeatureExtractor(nn.Module):
     """
     Neural network to be used in the DeepGP
@@ -134,11 +163,27 @@ def __init__(
         neural_network_args: dict | None = None,
         logger=None,
         surrogate_model_fit_args: dict | None = None,
+        # IMPORTANT: Checkpointing does not use file locking,
+        # IMPORTANT: hence, it is not suitable for multiprocessing settings
+        checkpointing: bool = False,
+        root_directory: Path | str | None = None,
+        checkpoint_file: Path | str = "surrogate_checkpoint.pth",
+        refine_epochs: int = 50,
         **kwargs,  # pylint: disable=unused-argument
     ):
         self.surrogate_model_fit_args = (
             surrogate_model_fit_args if surrogate_model_fit_args is not None else {}
         )
+
+        self.checkpointing = checkpointing
+        self.refine_epochs = refine_epochs
+        if checkpointing:
+            assert (
+                root_directory is not None
+            ), "neps root_directory must be provided for checkpointing"
+            self.root_dir = Path(os.getcwd(), root_directory)
+            self.checkpoint_path = Path(os.getcwd(), root_directory, checkpoint_file)
+
         super().__init__()
         self.__preprocess_search_space(pipeline_space)
         # set the categories array for the encoder
@@ -150,7 +195,7 @@ def __init__(
         self.device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-        self.device = torch.device("cpu")
+        # self.device = torch.device("cpu")
 
         # Save the NN args, necessary for preprocessing
         self.cnn_kernel_size = neural_network_args.get("cnn_kernel_size", 3)
@@ -339,6 +384,7 @@ def _fit(
         optimizer_args: dict | None = None,
         early_stopping: bool = True,
         patience: int = 10,
+        perf_patience: int = 10,
     ):
         self.__reset_xy(
             x_train,
@@ -348,6 +394,15 @@
             normalize_budget=normalize_budget,
         )
 
+        if self.checkpointing:
+            non_improvement_steps = count_non_improvement_steps(self.root_dir)
+            # If checkpointing and patience is not exhausted, load a partial model
+            if self.checkpoint_path.exists() and non_improvement_steps < perf_patience:
+                n_epochs = self.refine_epochs
+                self.load_checkpoint(self.checkpoint_path)
+            self.logger.info(f"No improvement: {non_improvement_steps}")
+            self.logger.info(f"N Epochs: {n_epochs}")
+
         self.model.to(self.device)
         self.likelihood.to(self.device)
         self.nn.to(self.device)
@@ -363,6 +418,7 @@
             early_stopping=early_stopping,
             patience=patience,
         )
+        self.save_checkpoint(self.checkpoint_path)
 
     def __train_model(
         self,
@@ -522,6 +578,53 @@ def predict(
 
         return means, cov
 
+    def load_checkpoint(self, checkpoint_path: str | Path):
+        """
+        Load the state from a previous checkpoint.
+        """
+        checkpoint = torch.load(checkpoint_path)
+        self.model.load_state_dict(checkpoint["gp_state_dict"])
+        self.nn.load_state_dict(checkpoint["nn_state_dict"])
+        self.likelihood.load_state_dict(checkpoint["likelihood_state_dict"])
+
+    def save_checkpoint(self, checkpoint_path: str | Path, state: dict | None = None):
+        """
+        Save the given state or the current state in a
+        checkpoint file.
+
+        Args:
+            checkpoint_path: path to the checkpoint file
+            state: The state to save; if None, the
+                current state is saved.
+        """
+
+        if state is None:
+            torch.save(
+                self.get_state(),
+                checkpoint_path,
+            )
+        else:
+            torch.save(
+                state,
+                checkpoint_path,
+            )
+
+    def get_state(self) -> dict[str, dict]:
+        """
+        Get the current state of the surrogate.
+
+        Returns:
+            current_state: A dictionary that represents
+                the current state of the surrogate model.
+        """
+        current_state = {
+            "gp_state_dict": deepcopy(self.model.state_dict()),
+            "nn_state_dict": deepcopy(self.nn.state_dict()),
+            "likelihood_state_dict": deepcopy(self.likelihood.state_dict()),
+        }
+
+        return current_state
+
 
 if __name__ == "__main__":
     print(torch.version.__version__)
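
---

For reference, a minimal, self-contained sketch of the neps root-directory layout that the new count_non_improvement_steps() helper reads. The file names and the import path come from the patch above; the loss values and the temporary directory are invented for illustration, and running it assumes a neps checkout with this patch (and its torch/gpytorch dependencies) installed.

import tempfile
from pathlib import Path

from neps.optimizers.bayesian_optimization.models.deepGP import count_non_improvement_steps

# Fabricated example of the two log files the helper parses.
root = Path(tempfile.mkdtemp())
# One "Loss: <value>" line per evaluated config, in exploration order.
(root / "all_losses_and_configs.txt").write_text(
    "Loss: 0.90\nLoss: 0.40\nLoss: 0.55\nLoss: 0.60\n", encoding="utf-8"
)
# Incumbent trajectory; the last line is the best loss seen so far.
(root / "best_loss_trajectory.txt").write_text("0.90\n0.40\n", encoding="utf-8")

# The two most recent evaluations (0.55, 0.60) are worse than the incumbent 0.40,
# so the helper counts 2 consecutive non-improving steps.
print(count_non_improvement_steps(root))  # -> 2

With the patch applied, _fit() compares this count against perf_patience: as long as fewer than perf_patience consecutive evaluations have failed to improve the incumbent and a checkpoint exists, the checkpoint is loaded and only refine_epochs additional epochs are trained; otherwise the surrogate is retrained from scratch.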