model cache for faster evaluation
rodrigo-arenas committed Sep 12, 2024
1 parent 4c3dd54 commit 8b6bf04
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -1,4 +1,4 @@
-scikit-learn>=1.1.0
+scikit-learn>=1.3.0
 deap>=1.3.3
 numpy>=1.19.0
 pytest==7.4.0
2 changes: 2 additions & 0 deletions pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --verbose --color=yes --assert=plain --cov-fail-under=95 --cov-config=.coveragerc --cov=./ -p no:warnings --tb=short --cov-report=term-missing:skip-covered
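
Note: addopts makes a bare pytest invocation behave as if these flags were passed on the command line, so the suite now fails outright whenever total coverage drops below 95% (--cov-fail-under=95).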
44 changes: 42 additions & 2 deletions sklearn_genetic/genetic_search.py
@@ -259,6 +259,7 @@ def __init__(
         self.return_train_score = return_train_score
         self.creator = creator
         self.log_config = log_config
+        self.fitness_cache = {}

         # Check that the estimator is compatible with scikit-learn
         if not is_classifier(self.estimator) and not is_regressor(self.estimator):
@@ -392,6 +393,17 @@ def evaluate(self, individual):
             key: individual[n] for n, key in enumerate(self.space.parameters)
         }

+        # Convert hyperparameters to a tuple to use as a key in the cache
+        individual_key = tuple(sorted(current_generation_params.items()))
+
+        # Check if the individual has already been evaluated
+        if individual_key in self.fitness_cache:
+            # Retrieve cached result
+            cached_result = self.fitness_cache[individual_key]
+            # Ensure the logbook is updated even if the individual is cached
+            self.logbook.record(parameters=cached_result["current_generation_params"])
+            return cached_result["fitness"]
+
         local_estimator = clone(self.estimator)
         local_estimator.set_params(**current_generation_params)

@@ -437,7 +449,15 @@ def evaluate(self, individual):
         # Log the hyperparameters and the cv-score
         self.logbook.record(parameters=current_generation_params)

-        return [score]
+        fitness_result = [score]
+
+        # Store the fitness result and the current generation parameters in the cache
+        self.fitness_cache[individual_key] = {
+            "fitness": fitness_result,
+            "current_generation_params": current_generation_params
+        }
+
+        return fitness_result

     def fit(self, X, y, callbacks=None):
         """
@@ -880,6 +900,7 @@ def __init__(
         self.return_train_score = return_train_score
         self.creator = creator
         self.log_config = log_config
+        self.fitness_cache = {}

         # Check that the estimator is compatible with scikit-learn
         if not is_classifier(self.estimator) and not is_regressor(self.estimator):
@@ -965,6 +986,16 @@ def evaluate(self, individual):
         local_estimator = clone(self.estimator)
         n_selected_features = np.sum(individual)

+        # Convert the individual to a tuple to use as a key in the cache
+        individual_key = tuple(individual)
+
+        # Check if the individual has already been evaluated
+        if individual_key in self.fitness_cache:
+            cached_result = self.fitness_cache[individual_key]
+            # Ensure the logbook is updated even if the individual is cached
+            self.logbook.record(parameters=cached_result["current_generation_features"])
+            return cached_result["fitness"]
+
         # Compute the cv-metrics using only the selected features
         cv_results = cross_validate(
             local_estimator,
@@ -1014,7 +1045,16 @@
         ):
             score = -self.criteria_sign * 100000

-        return [score, n_selected_features]
+        # Prepare the fitness result
+        fitness_result = [score, n_selected_features]
+
+        # Store the fitness result and the current generation features in the cache
+        self.fitness_cache[individual_key] = {
+            "fitness": fitness_result,
+            "current_generation_features": current_generation_features
+        }
+
+        return fitness_result

     def fit(self, X, y, callbacks=None):
         """
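Aside (not part of the diff): the GAFeatureSelectionCV variant keys the cache on the raw binary mask instead. A minimal sketch, assuming individuals are 0/1 masks over feature columns; evaluate_mask and per_feature_scores are hypothetical stand-ins for the real cross-validation scoring.

    import numpy as np

    fitness_cache = {}

    def evaluate_mask(individual, per_feature_scores):
        # tuple() turns the mask into a hashable, order-preserving key; unlike
        # the hyperparameter case, no sorting is wanted, because position
        # encodes which feature each bit refers to.
        key = tuple(individual)
        if key in fitness_cache:
            return fitness_cache[key]["fitness"]
        n_selected = int(np.sum(individual))
        score = float(np.dot(individual, per_feature_scores))  # stand-in CV score
        fitness = [score, n_selected]
        fitness_cache[key] = {"fitness": fitness, "features": key}
        return fitness

    mask = np.array([1, 0, 1])
    scores = np.array([0.4, 0.1, 0.5])
    evaluate_mask(mask, scores)  # computed
    evaluate_mask(mask, scores)  # cache hit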
