From b17a5dea1b28c23e9f6916354a9042438b65386f Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:51:09 -0500 Subject: [PATCH 1/6] creating new callback --- sklearn_genetic/callbacks/model_checkpoint.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 sklearn_genetic/callbacks/model_checkpoint.py diff --git a/sklearn_genetic/callbacks/model_checkpoint.py b/sklearn_genetic/callbacks/model_checkpoint.py new file mode 100644 index 0000000..4053fe6 --- /dev/null +++ b/sklearn_genetic/callbacks/model_checkpoint.py @@ -0,0 +1,49 @@ +import pickle +from .base import BaseCallback +from .loggers import LogbookSaver +from copy import deepcopy + + +class ModelCheckpoint(BaseCallback): + def __init__(self, checkpoint_path, **dump_options): + self.checkpoint_path = checkpoint_path + self.dump_options = dump_options + + def on_step(self, record=None, logbook=None, estimator=None): + try: + if logbook is not None and len(logbook) > 0: + logbook_saver = LogbookSaver(self.checkpoint_path, **self.dump_options) # noqa + logbook_saver.on_step(record, logbook, estimator) + + estimator_state = { + 'estimator': estimator.estimator, + 'cv': estimator.cv, + 'scoring': estimator.scoring, + 'population_size': estimator.population_size, + 'generations': estimator.generations, + 'crossover_probability': estimator.crossover_probability, + 'mutation_probability': estimator.mutation_probability, + 'param_grid': estimator.param_grid, + 'algorithm': estimator.algorithm, + 'param_grid': estimator.param_grid, + } + checkpoint_data = { + 'estimator_state': estimator_state, + 'logbook': deepcopy(logbook) + } + with open(self.checkpoint_path, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"Checkpoint save in {self.checkpoint_path}") + + except Exception as e: + print(f"Error saving checkpoint: {e}") + + def load(self): + """Load the model state from the checkpoint file.""" + try: + with open(self.checkpoint_path, 'rb') as f: + checkpoint_data = pickle.load(f) + return checkpoint_data + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None From f6b038694b22fb8d933d03623676a9c86273fa32 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:52:25 -0500 Subject: [PATCH 2/6] Addding save,load and modifying fit --- sklearn_genetic/genetic_search.py | 68 +++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/sklearn_genetic/genetic_search.py b/sklearn_genetic/genetic_search.py index 8decc0e..6c41152 100644 --- a/sklearn_genetic/genetic_search.py +++ b/sklearn_genetic/genetic_search.py @@ -30,6 +30,10 @@ from .utils.random import weighted_bool_individual from .utils.tools import cxUniform, mutFlipBit, novelty_scorer +import pickle +import os +from .callbacks.model_checkpoint import ModelCheckpoint + class GASearchCV(BaseSearchCV): """ @@ -524,6 +528,16 @@ def fit(self, X, y, callbacks=None): # Make sure the callbacks are valid self.callbacks = check_callback(callbacks) + # Load state if a checkpoint exists + for callback in self.callbacks: + if isinstance(callback, ModelCheckpoint): + if os.path.exists(callback.checkpoint_path): + checkpoint_data = callback.load() + if checkpoint_data: + self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa + self.logbook = checkpoint_data['logbook'] + break + if callable(self.scoring): self.scorer_ = self.scoring self.metrics_list = [self.refit_metric] @@ -601,6 +615,28 @@ def fit(self, X, y, callbacks=None): return self + def save(self, filepath): + """Save the current state of the GASearchCV instance to a file.""" + try: + checkpoint_data = self.__dict__ + with open(filepath, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"GASearchCV model successfully saved to {filepath}") + except Exception as e: + print(f"Error saving GASearchCV: {e}") + + @staticmethod + def load(filepath): + """Load a GASearchCV instance from a file.""" + try: + with open(filepath, 'rb') as f: + checkpoint_data = pickle.load(f) + model = GASearchCV(**checkpoint_data) + print(f"GASearchCV model successfully loaded from {filepath}") + return model + except Exception as e: + print(f"Error loading GASearchCV: {e}") + def _select_algorithm(self, pop, stats, hof): """ It selects the algorithm to run from the sklearn_genetic.algorithms module @@ -1131,6 +1167,16 @@ def fit(self, X, y, callbacks=None): # Make sure the callbacks are valid self.callbacks = check_callback(callbacks) + # Load state if a checkpoint exists + for callback in self.callbacks: + if isinstance(callback, ModelCheckpoint): + if os.path.exists(callback.checkpoint_path): + checkpoint_data = callback.load() + if checkpoint_data: + self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa + self.logbook = checkpoint_data['logbook'] + break + if callable(self.scoring): self.scorer_ = self.scoring self.metrics_list = [self.refit_metric] @@ -1192,6 +1238,28 @@ def fit(self, X, y, callbacks=None): return self + def save(self, filepath): + """Save the current state of the GAFeatureSelectionCV instance to a file.""" + try: + checkpoint_data = self.__dict__ + with open(filepath, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"GAFeatureSelectionCV model successfully saved to {filepath}") + except Exception as e: + print(f"Error saving GAFeatureSelectionCV: {e}") + + @staticmethod + def load(filepath): + """Load a GAFeatureSelectionCV instance from a file.""" + try: + with open(filepath, 'rb') as f: + checkpoint_data = pickle.load(f) + model = GAFeatureSelectionCV(**checkpoint_data) + print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa + return model + except Exception as e: + print(f"Error loading GAFeatureSelectionCV: {e}") + def _select_algorithm(self, pop, stats, hof): """ It selects the algorithm to run from the sklearn_genetic.algorithms module From c5926d91556e9de7ae67ffed892dcce815714915 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:52:48 -0500 Subject: [PATCH 3/6] adding callback to init --- sklearn_genetic/callbacks/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn_genetic/callbacks/__init__.py b/sklearn_genetic/callbacks/__init__.py index a3cbfac..d7d6a06 100644 --- a/sklearn_genetic/callbacks/__init__.py +++ b/sklearn_genetic/callbacks/__init__.py @@ -5,6 +5,7 @@ TimerStopping, ) from .loggers import ProgressBar, LogbookSaver, TensorBoard +from .model_checkpoint import ModelCheckpoint __all__ = [ "ProgressBar", @@ -14,4 +15,5 @@ "TimerStopping", "LogbookSaver", "TensorBoard", + "ModelCheckpoint" ] From 2e3930b3867239287d19f74a2ce75644eb536bb7 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:53:05 -0500 Subject: [PATCH 4/6] adding a sample test --- sklearn_genetic/tests/test_genetic_search.py | 50 ++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/sklearn_genetic/tests/test_genetic_search.py b/sklearn_genetic/tests/test_genetic_search.py index 4f2bf04..ff2d9af 100644 --- a/sklearn_genetic/tests/test_genetic_search.py +++ b/sklearn_genetic/tests/test_genetic_search.py @@ -10,6 +10,7 @@ from sklearn.metrics import make_scorer import numpy as np +import os from .. import GASearchCV from ..space import Integer, Categorical, Continuous @@ -19,6 +20,7 @@ ConsecutiveStopping, TimerStopping, ProgressBar, + ModelCheckpoint ) from ..schedules import ExponentialAdapter, InverseAdapter @@ -659,3 +661,51 @@ def test_expected_ga_schedulers(): assert "params" in cv_result_keys assert crossover_scheduler.current_value + mutation_scheduler.current_value <= 1 + + +def test_checkpoint_functionality(): + clf = SGDClassifier(loss="modified_huber", fit_intercept=True) + gen = 5 + evolved_estimator = GASearchCV( + clf, + cv=3, + scoring="accuracy", + population_size=6, + generations=gen, + tournament_size=3, + param_grid={ + "l1_ratio": Continuous(0, 1), + "alpha": Continuous(1e-4, 1), + "average": Categorical([True, False]), + }, + ) + checkpoint_path = 'test_checkpoint.pkl' + checkpoint = ModelCheckpoint(checkpoint_path=checkpoint_path) # noqa + evolved_estimator.fit(X_train, y_train, callbacks=checkpoint) + + checkpoint_data = checkpoint.load() + + assert 'estimator' in checkpoint_data['estimator_state'] + assert 'algorithm' in checkpoint_data['estimator_state'] + assert 'logbook' in checkpoint_data + + restored_estimator = GASearchCV(**checkpoint_data['estimator_state']) + + assert restored_estimator.algorithm == checkpoint_data['estimator_state']['algorithm'] # noqa + + assert len(checkpoint_data['logbook']) == gen + 1 + + test_estimator = GASearchCV( + clf, + param_grid={ + "l1_ratio": Continuous(0, 1), + "alpha": Continuous(1e-1, 1), + "average": Categorical([True, False]), + },) + + test_estimator.load('checkpoint_path') + + assert restored_estimator.algorithm == test_estimator.algorithm # noqa + + if os.path.exists(checkpoint_path): + os.remove(checkpoint_path) \ No newline at end of file From 390cc2e847b560307a69afef28196d3ed4063a96 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 14:00:32 -0500 Subject: [PATCH 5/6] fix format --- sklearn_genetic/callbacks/__init__.py | 2 +- sklearn_genetic/callbacks/model_checkpoint.py | 31 ++--- sklearn_genetic/genetic_search.py | 126 +++++++++--------- sklearn_genetic/tests/test_genetic_search.py | 31 +++-- 4 files changed, 94 insertions(+), 96 deletions(-) diff --git a/sklearn_genetic/callbacks/__init__.py b/sklearn_genetic/callbacks/__init__.py index d7d6a06..f030b40 100644 --- a/sklearn_genetic/callbacks/__init__.py +++ b/sklearn_genetic/callbacks/__init__.py @@ -15,5 +15,5 @@ "TimerStopping", "LogbookSaver", "TensorBoard", - "ModelCheckpoint" + "ModelCheckpoint", ] diff --git a/sklearn_genetic/callbacks/model_checkpoint.py b/sklearn_genetic/callbacks/model_checkpoint.py index 4053fe6..4a2f482 100644 --- a/sklearn_genetic/callbacks/model_checkpoint.py +++ b/sklearn_genetic/callbacks/model_checkpoint.py @@ -5,7 +5,7 @@ class ModelCheckpoint(BaseCallback): - def __init__(self, checkpoint_path, **dump_options): + def __init__(self, checkpoint_path, **dump_options): self.checkpoint_path = checkpoint_path self.dump_options = dump_options @@ -16,22 +16,19 @@ def on_step(self, record=None, logbook=None, estimator=None): logbook_saver.on_step(record, logbook, estimator) estimator_state = { - 'estimator': estimator.estimator, - 'cv': estimator.cv, - 'scoring': estimator.scoring, - 'population_size': estimator.population_size, - 'generations': estimator.generations, - 'crossover_probability': estimator.crossover_probability, - 'mutation_probability': estimator.mutation_probability, - 'param_grid': estimator.param_grid, - 'algorithm': estimator.algorithm, - 'param_grid': estimator.param_grid, + "estimator": estimator.estimator, + "cv": estimator.cv, + "scoring": estimator.scoring, + "population_size": estimator.population_size, + "generations": estimator.generations, + "crossover_probability": estimator.crossover_probability, + "mutation_probability": estimator.mutation_probability, + "param_grid": estimator.param_grid, + "algorithm": estimator.algorithm, + "param_grid": estimator.param_grid, } - checkpoint_data = { - 'estimator_state': estimator_state, - 'logbook': deepcopy(logbook) - } - with open(self.checkpoint_path, 'wb') as f: + checkpoint_data = {"estimator_state": estimator_state, "logbook": deepcopy(logbook)} + with open(self.checkpoint_path, "wb") as f: pickle.dump(checkpoint_data, f) print(f"Checkpoint save in {self.checkpoint_path}") @@ -41,7 +38,7 @@ def on_step(self, record=None, logbook=None, estimator=None): def load(self): """Load the model state from the checkpoint file.""" try: - with open(self.checkpoint_path, 'rb') as f: + with open(self.checkpoint_path, "rb") as f: checkpoint_data = pickle.load(f) return checkpoint_data except Exception as e: diff --git a/sklearn_genetic/genetic_search.py b/sklearn_genetic/genetic_search.py index 6c41152..0bbdf6f 100644 --- a/sklearn_genetic/genetic_search.py +++ b/sklearn_genetic/genetic_search.py @@ -222,29 +222,29 @@ class GASearchCV(BaseSearchCV): """ def __init__( - self, - estimator, - cv=3, - param_grid=None, - scoring=None, - population_size=50, - generations=80, - crossover_probability=0.2, - mutation_probability=0.8, - tournament_size=3, - elitism=True, - verbose=True, - keep_top_k=1, - criteria="max", - algorithm="eaMuPlusLambda", - refit=True, - n_jobs=1, - pre_dispatch="2*n_jobs", - error_score=np.nan, - return_train_score=False, - log_config=None, - use_cache=True, - warm_start_configs=None, + self, + estimator, + cv=3, + param_grid=None, + scoring=None, + population_size=50, + generations=80, + crossover_probability=0.2, + mutation_probability=0.8, + tournament_size=3, + elitism=True, + verbose=True, + keep_top_k=1, + criteria="max", + algorithm="eaMuPlusLambda", + refit=True, + n_jobs=1, + pre_dispatch="2*n_jobs", + error_score=np.nan, + return_train_score=False, + log_config=None, + use_cache=True, + warm_start_configs=None, ): self.estimator = estimator self.cv = cv @@ -315,7 +315,6 @@ def _register(self): creator.create("FitnessMax", base.Fitness, weights=[self.criteria_sign, 1.0]) creator.create("Individual", list, fitness=creator.FitnessMax) - attributes = [] # Assign all the parameters defined in the param_grid # It uses the distribution parameter to set the sampling function @@ -362,7 +361,6 @@ def _register(self): self._stats.register("fitness_max", np.max, axis=0) self._stats.register("fitness_min", np.min, axis=0) - self.logbook = tools.Logbook() def _initialize_population(self): @@ -494,7 +492,7 @@ def evaluate(self, individual): # Store the fitness result and the current generation parameters in the cache self.fitness_cache[individual_key] = { "fitness": fitness_result, - "current_generation_params": current_generation_params + "current_generation_params": current_generation_params, } return fitness_result @@ -531,11 +529,11 @@ def fit(self, X, y, callbacks=None): # Load state if a checkpoint exists for callback in self.callbacks: if isinstance(callback, ModelCheckpoint): - if os.path.exists(callback.checkpoint_path): + if os.path.exists(callback.checkpoint_path): checkpoint_data = callback.load() if checkpoint_data: - self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa - self.logbook = checkpoint_data['logbook'] + self.estimator.__dict__.update(checkpoint_data["estimator_state"]) # noqa + self.logbook = checkpoint_data["logbook"] break if callable(self.scoring): @@ -619,17 +617,17 @@ def save(self, filepath): """Save the current state of the GASearchCV instance to a file.""" try: checkpoint_data = self.__dict__ - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: pickle.dump(checkpoint_data, f) print(f"GASearchCV model successfully saved to {filepath}") except Exception as e: print(f"Error saving GASearchCV: {e}") - + @staticmethod def load(filepath): """Load a GASearchCV instance from a file.""" try: - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: checkpoint_data = pickle.load(f) model = GASearchCV(**checkpoint_data) print(f"GASearchCV model successfully loaded from {filepath}") @@ -931,28 +929,28 @@ class GAFeatureSelectionCV(MetaEstimatorMixin, SelectorMixin, BaseEstimator): """ def __init__( - self, - estimator, - cv=3, - scoring=None, - population_size=50, - generations=80, - crossover_probability=0.2, - mutation_probability=0.8, - tournament_size=3, - elitism=True, - max_features=None, - verbose=True, - keep_top_k=1, - criteria="max", - algorithm="eaMuPlusLambda", - refit=True, - n_jobs=1, - pre_dispatch="2*n_jobs", - error_score=np.nan, - return_train_score=False, - log_config=None, - use_cache=True, + self, + estimator, + cv=3, + scoring=None, + population_size=50, + generations=80, + crossover_probability=0.2, + mutation_probability=0.8, + tournament_size=3, + elitism=True, + max_features=None, + verbose=True, + keep_top_k=1, + criteria="max", + algorithm="eaMuPlusLambda", + refit=True, + n_jobs=1, + pre_dispatch="2*n_jobs", + error_score=np.nan, + return_train_score=False, + log_config=None, + use_cache=True, ): self.estimator = estimator self.cv = cv @@ -1119,7 +1117,7 @@ def evaluate(self, individual): # Penalize individuals with more features than the max_features parameter if self.max_features and ( - n_selected_features > self.max_features or n_selected_features == 0 + n_selected_features > self.max_features or n_selected_features == 0 ): score = -self.criteria_sign * 100000 @@ -1130,7 +1128,7 @@ def evaluate(self, individual): # Store the fitness result and the current generation features in the cache self.fitness_cache[individual_key] = { "fitness": fitness_result, - "current_generation_features": current_generation_features + "current_generation_features": current_generation_features, } return fitness_result @@ -1170,13 +1168,13 @@ def fit(self, X, y, callbacks=None): # Load state if a checkpoint exists for callback in self.callbacks: if isinstance(callback, ModelCheckpoint): - if os.path.exists(callback.checkpoint_path): + if os.path.exists(callback.checkpoint_path): checkpoint_data = callback.load() if checkpoint_data: - self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa - self.logbook = checkpoint_data['logbook'] + self.estimator.__dict__.update(checkpoint_data["estimator_state"]) # noqa + self.logbook = checkpoint_data["logbook"] break - + if callable(self.scoring): self.scorer_ = self.scoring self.metrics_list = [self.refit_metric] @@ -1241,8 +1239,8 @@ def fit(self, X, y, callbacks=None): def save(self, filepath): """Save the current state of the GAFeatureSelectionCV instance to a file.""" try: - checkpoint_data = self.__dict__ - with open(filepath, 'wb') as f: + checkpoint_data = self.__dict__ + with open(filepath, "wb") as f: pickle.dump(checkpoint_data, f) print(f"GAFeatureSelectionCV model successfully saved to {filepath}") except Exception as e: @@ -1252,10 +1250,10 @@ def save(self, filepath): def load(filepath): """Load a GAFeatureSelectionCV instance from a file.""" try: - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: checkpoint_data = pickle.load(f) model = GAFeatureSelectionCV(**checkpoint_data) - print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa + print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa return model except Exception as e: print(f"Error loading GAFeatureSelectionCV: {e}") diff --git a/sklearn_genetic/tests/test_genetic_search.py b/sklearn_genetic/tests/test_genetic_search.py index ff2d9af..c42f0f0 100644 --- a/sklearn_genetic/tests/test_genetic_search.py +++ b/sklearn_genetic/tests/test_genetic_search.py @@ -20,7 +20,7 @@ ConsecutiveStopping, TimerStopping, ProgressBar, - ModelCheckpoint + ModelCheckpoint, ) from ..schedules import ExponentialAdapter, InverseAdapter @@ -607,8 +607,10 @@ def test_expected_ga_schedulers(): "average": Categorical([True, False]), "max_iter": Integer(700, 1000), }, - warm_start_configs=[{"l1_ratio": 0.5, "alpha": 0.5, "average": False, "max_iter": 400}, - {"l1_ratio": 0.2, "alpha": 0.8, "average": True, "max_iter": 400}], + warm_start_configs=[ + {"l1_ratio": 0.5, "alpha": 0.5, "average": False, "max_iter": 400}, + {"l1_ratio": 0.2, "alpha": 0.8, "average": True, "max_iter": 400}, + ], verbose=False, algorithm="eaSimple", n_jobs=-1, @@ -679,21 +681,21 @@ def test_checkpoint_functionality(): "average": Categorical([True, False]), }, ) - checkpoint_path = 'test_checkpoint.pkl' + checkpoint_path = "test_checkpoint.pkl" checkpoint = ModelCheckpoint(checkpoint_path=checkpoint_path) # noqa evolved_estimator.fit(X_train, y_train, callbacks=checkpoint) checkpoint_data = checkpoint.load() - assert 'estimator' in checkpoint_data['estimator_state'] - assert 'algorithm' in checkpoint_data['estimator_state'] - assert 'logbook' in checkpoint_data + assert "estimator" in checkpoint_data["estimator_state"] + assert "algorithm" in checkpoint_data["estimator_state"] + assert "logbook" in checkpoint_data - restored_estimator = GASearchCV(**checkpoint_data['estimator_state']) + restored_estimator = GASearchCV(**checkpoint_data["estimator_state"]) - assert restored_estimator.algorithm == checkpoint_data['estimator_state']['algorithm'] # noqa - - assert len(checkpoint_data['logbook']) == gen + 1 + assert restored_estimator.algorithm == checkpoint_data["estimator_state"]["algorithm"] # noqa + + assert len(checkpoint_data["logbook"]) == gen + 1 test_estimator = GASearchCV( clf, @@ -701,11 +703,12 @@ def test_checkpoint_functionality(): "l1_ratio": Continuous(0, 1), "alpha": Continuous(1e-1, 1), "average": Categorical([True, False]), - },) + }, + ) - test_estimator.load('checkpoint_path') + test_estimator.load("checkpoint_path") assert restored_estimator.algorithm == test_estimator.algorithm # noqa if os.path.exists(checkpoint_path): - os.remove(checkpoint_path) \ No newline at end of file + os.remove(checkpoint_path) From 39baa19b9db2f066f81fdb2bbd121b0f3ffa4e56 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 15:18:14 -0500 Subject: [PATCH 6/6] fix in fit,save and load --- sklearn_genetic/genetic_search.py | 29 ++++++++++++-------- sklearn_genetic/tests/test_genetic_search.py | 8 ++++-- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/sklearn_genetic/genetic_search.py b/sklearn_genetic/genetic_search.py index 0bbdf6f..fd8d434 100644 --- a/sklearn_genetic/genetic_search.py +++ b/sklearn_genetic/genetic_search.py @@ -532,7 +532,7 @@ def fit(self, X, y, callbacks=None): if os.path.exists(callback.checkpoint_path): checkpoint_data = callback.load() if checkpoint_data: - self.estimator.__dict__.update(checkpoint_data["estimator_state"]) # noqa + self.__dict__.update(checkpoint_data["estimator_state"]) # noqa self.logbook = checkpoint_data["logbook"] break @@ -616,22 +616,24 @@ def fit(self, X, y, callbacks=None): def save(self, filepath): """Save the current state of the GASearchCV instance to a file.""" try: - checkpoint_data = self.__dict__ + checkpoint_data = {"estimator_state": self.__dict__, "logbook": None} + if hasattr(self, "logbook"): + checkpoint_data["logbook"] = self.logbook with open(filepath, "wb") as f: pickle.dump(checkpoint_data, f) print(f"GASearchCV model successfully saved to {filepath}") except Exception as e: print(f"Error saving GASearchCV: {e}") - @staticmethod - def load(filepath): + def load(self, filepath): """Load a GASearchCV instance from a file.""" try: with open(filepath, "rb") as f: checkpoint_data = pickle.load(f) - model = GASearchCV(**checkpoint_data) + for key, value in checkpoint_data["estimator_state"].items(): + setattr(self, key, value) + self.logbook = checkpoint_data["logbook"] print(f"GASearchCV model successfully loaded from {filepath}") - return model except Exception as e: print(f"Error loading GASearchCV: {e}") @@ -1171,7 +1173,7 @@ def fit(self, X, y, callbacks=None): if os.path.exists(callback.checkpoint_path): checkpoint_data = callback.load() if checkpoint_data: - self.estimator.__dict__.update(checkpoint_data["estimator_state"]) # noqa + self.__dict__.update(checkpoint_data["estimator_state"]) # noqa self.logbook = checkpoint_data["logbook"] break @@ -1239,22 +1241,25 @@ def fit(self, X, y, callbacks=None): def save(self, filepath): """Save the current state of the GAFeatureSelectionCV instance to a file.""" try: - checkpoint_data = self.__dict__ + checkpoint_data = {"estimator_state": self.__dict__, "logbook": None} + if hasattr(self, "logbook"): + checkpoint_data["logbook"] = self.logbook + with open(filepath, "wb") as f: pickle.dump(checkpoint_data, f) print(f"GAFeatureSelectionCV model successfully saved to {filepath}") except Exception as e: print(f"Error saving GAFeatureSelectionCV: {e}") - @staticmethod - def load(filepath): + def load(self, filepath): """Load a GAFeatureSelectionCV instance from a file.""" try: with open(filepath, "rb") as f: checkpoint_data = pickle.load(f) - model = GAFeatureSelectionCV(**checkpoint_data) + for key, value in checkpoint_data["estimator_state"].items(): + setattr(self, key, value) + self.logbook = checkpoint_data["logbook"] print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa - return model except Exception as e: print(f"Error loading GAFeatureSelectionCV: {e}") diff --git a/sklearn_genetic/tests/test_genetic_search.py b/sklearn_genetic/tests/test_genetic_search.py index c42f0f0..6c0dcdc 100644 --- a/sklearn_genetic/tests/test_genetic_search.py +++ b/sklearn_genetic/tests/test_genetic_search.py @@ -697,18 +697,22 @@ def test_checkpoint_functionality(): assert len(checkpoint_data["logbook"]) == gen + 1 + restored_estimator.save(checkpoint_path) + test_estimator = GASearchCV( clf, param_grid={ "l1_ratio": Continuous(0, 1), "alpha": Continuous(1e-1, 1), - "average": Categorical([True, False]), + "average": Categorical([False, True]), }, ) - test_estimator.load("checkpoint_path") + test_estimator.load(checkpoint_path) assert restored_estimator.algorithm == test_estimator.algorithm # noqa + assert restored_estimator.scoring == test_estimator.scoring # noqa + assert restored_estimator.generations == test_estimator.generations # noqa if os.path.exists(checkpoint_path): os.remove(checkpoint_path)