Commit 5bdb85d

warm_start_configs for model initialization
1 parent 741ef00 commit 5bdb85d

5 files changed, 82 additions & 4 deletions


README.rst

Lines changed: 9 additions & 0 deletions
@@ -102,12 +102,19 @@ Example: Hyperparameters Tuning
clf = RandomForestClassifier()

+# Defines the possible values to search
param_grid = {'min_weight_fraction_leaf': Continuous(0.01, 0.5, distribution='log-uniform'),
              'bootstrap': Categorical([True, False]),
              'max_depth': Integer(2, 30),
              'max_leaf_nodes': Integer(2, 35),
              'n_estimators': Integer(100, 300)}

+# Seed solutions
+warm_start_configs = [
+    {"min_weight_fraction_leaf": 0.02, "bootstrap": True, "max_depth": None, "n_estimators": 100},
+    {"min_weight_fraction_leaf": 0.4, "bootstrap": True, "max_depth": 5, "n_estimators": 200},
+]
+
cv = StratifiedKFold(n_splits=3, shuffle=True)

evolved_estimator = GASearchCV(estimator=clf,
@@ -118,6 +125,8 @@ Example: Hyperparameters Tuning
                               param_grid=param_grid,
                               n_jobs=-1,
                               verbose=True,
+                              use_cache=True,
+                              warm_start_configs=warm_start_configs,
                               keep_top_k=4)

# Train and optimize the estimator
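For reference, the snippet above slots into the README's full example roughly as follows. This is a sketch, not part of the diff: the imports, the digits dataset, and the extra GASearchCV arguments (scoring, population_size, generations) are assumed from the surrounding README; only use_cache and warm_start_configs are new in this commit.

    from sklearn.datasets import load_digits
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import StratifiedKFold, train_test_split
    from sklearn_genetic import GASearchCV
    from sklearn_genetic.space import Categorical, Continuous, Integer

    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

    clf = RandomForestClassifier()

    # Defines the possible values to search
    param_grid = {'min_weight_fraction_leaf': Continuous(0.01, 0.5, distribution='log-uniform'),
                  'bootstrap': Categorical([True, False]),
                  'max_depth': Integer(2, 30),
                  'max_leaf_nodes': Integer(2, 35),
                  'n_estimators': Integer(100, 300)}

    # Seed solutions: two hand-picked configurations for the initial population
    warm_start_configs = [
        {"min_weight_fraction_leaf": 0.02, "bootstrap": True, "max_depth": None, "n_estimators": 100},
        {"min_weight_fraction_leaf": 0.4, "bootstrap": True, "max_depth": 5, "n_estimators": 200},
    ]

    cv = StratifiedKFold(n_splits=3, shuffle=True)

    evolved_estimator = GASearchCV(estimator=clf,
                                   cv=cv,
                                   scoring='accuracy',
                                   population_size=10,
                                   generations=25,
                                   param_grid=param_grid,
                                   n_jobs=-1,
                                   verbose=True,
                                   use_cache=True,
                                   warm_start_configs=warm_start_configs,
                                   keep_top_k=4)

    # Train and optimize the estimator
    evolved_estimator.fit(X_train, y_train)
    print(evolved_estimator.best_params_)

With population_size=10, the two warm-start configurations seed two individuals and the remaining eight are sampled randomly from param_grid.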

dev-requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ sphinx_rtd_theme
sphinx-copybutton
numpydoc
nbsphinx
+ipython>=8.27.0
+Pygments>=2.18.0
tensorflow>=2.4.0
tqdm>=4.61.1
tk

docs/release_notes.rst

Lines changed: 24 additions & 3 deletions
@@ -10,9 +10,30 @@ What's new in 0.11.0dev0
Features:
^^^^^^^^^

-* Added a parameter named `use_cache`, defaults to `True`, If set to true it will avoid to re-evaluating solutions that have already seen,
-  otherwise it will always evaluate the solutions to get the performance metrics
-
+* Added a parameter `use_cache`, which defaults to ``True``. When enabled, the algorithm will skip re-evaluating solutions that have already been evaluated, retrieving the performance metrics from the cache instead.
+  If ``use_cache`` is set to ``False``, the algorithm will always re-evaluate solutions, even if they have been seen before, to obtain fresh performance metrics.
+* Added a parameter in `GASearchCV` named ``warm_start_configs``, which defaults to ``None``: a list of predefined hyperparameter configurations to seed the initial population.
+  Each element in the list is a dictionary where the keys are the names of the hyperparameters,
+  and the values are the corresponding hyperparameter values to be used for the individual.
+
+  Example:
+
+  .. code-block:: python
+     :linenos:
+
+     warm_start_configs = [
+         {"min_weight_fraction_leaf": 0.02, "bootstrap": True, "max_depth": None, "n_estimators": 100},
+         {"min_weight_fraction_leaf": 0.4, "bootstrap": True, "max_depth": 5, "n_estimators": 200},
+     ]
+
+  The genetic algorithm will initialize part of the population with these configurations to
+  warm-start the optimization process. The remaining individuals in the population will
+  be initialized randomly according to the defined hyperparameter space.
+
+  This parameter is useful when prior knowledge of good hyperparameter configurations exists,
+  allowing the algorithm to focus on refining known good solutions while still exploring new
+  areas of the hyperparameter space. If set to ``None``, the entire population will be initialized
+  randomly.

What's new in 0.10.1
--------------------
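The use_cache note above boils down to memoising fitness evaluations. A minimal sketch of the idea, for illustration only: the helper and cache key below are hypothetical, although GASearchCV does keep its cache in a fitness_cache dictionary (see the genetic_search.py diff further down).

    # Illustrative memoisation of fitness evaluations, keyed by the individual's
    # hyperparameter values. evaluate_cv stands in for the expensive
    # cross-validation step that is normally run once per individual.
    fitness_cache = {}

    def cached_evaluate(individual_values, evaluate_cv, use_cache=True):
        key = tuple(individual_values)
        if use_cache and key in fitness_cache:
            # Solution already seen: reuse the stored metrics instead of re-fitting
            return fitness_cache[key]
        score = evaluate_cv(individual_values)
        if use_cache:
            fitness_cache[key] = score
        return score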

sklearn_genetic/genetic_search.py

Lines changed: 26 additions & 1 deletion
@@ -240,6 +240,7 @@ def __init__(
        return_train_score=False,
        log_config=None,
        use_cache=True,
+        warm_start_configs=None,
    ):
        self.estimator = estimator
        self.cv = cv
@@ -266,6 +267,7 @@ def __init__(
        self.log_config = log_config
        self.use_cache = use_cache
        self.fitness_cache = {}
+        self.warm_start_configs = warm_start_configs or []

        # Check that the estimator is compatible with scikit-learn
        if not is_classifier(self.estimator) and not is_regressor(self.estimator):
@@ -346,7 +348,7 @@ def _register(self):
        self.toolbox.register("evaluate", self.evaluate)

-        self._pop = self.toolbox.population(n=self.population_size)
+        self._pop = self._initialize_population()
        self._hof = tools.HallOfFame(self.keep_top_k)

        self._stats = tools.Statistics(lambda ind: ind.fitness.values)
@@ -357,6 +359,29 @@ def _register(self):
        self.logbook = tools.Logbook()

+    def _initialize_population(self):
+        """
+        Initialize the population, using warm-start configurations if provided.
+        """
+        population = []
+        # Seed part of the population with warm-start values
+        num_warm_start = min(len(self.warm_start_configs), self.population_size)
+
+        for config in self.warm_start_configs[:num_warm_start]:
+            # Sample an individual from the warm-start configuration
+            individual_values = self.space.sample_warm_start(config)
+            individual_values_list = list(individual_values.values())
+
+            # Manually create the individual and assign its fitness
+            individual = creator.Individual(individual_values_list)
+            population.append(individual)
+
+        # Fill the remaining population with random individuals
+        num_random = self.population_size - num_warm_start
+        population.extend(self.toolbox.population(n=num_random))
+
+        return population
+
    def mutate(self, individual):
        """
        This function is responsible for change a randomly selected parameter from an individual
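To make the seeding logic concrete, here is a standalone rendering of what _initialize_population does, with plain lists standing in for DEAP's creator.Individual and toolbox.population (the function and parameter names below are illustrative, not library API):

    def initialize_population(population_size, warm_start_configs, sample_warm_start, sample_random_individual):
        """Seed up to len(warm_start_configs) individuals, fill the rest randomly."""
        population = []
        num_warm_start = min(len(warm_start_configs), population_size)

        for config in warm_start_configs[:num_warm_start]:
            # sample_warm_start fills any hyperparameter missing from config with a
            # random draw, so every seeded individual gets a complete list of values
            values = sample_warm_start(config)
            population.append(list(values.values()))

        # Remaining slots are initialized randomly from the search space
        for _ in range(population_size - num_warm_start):
            population.append(sample_random_individual())

        return population

Because sample_warm_start iterates over the param_grid (see the space.py diff below), the seeded value lists follow the order in which the hyperparameters were declared.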

sklearn_genetic/space/space.py

Lines changed: 21 additions & 0 deletions
@@ -222,6 +222,27 @@ def __init__(self, param_grid: dict = None):
        self.param_grid = param_grid

+    def sample_warm_start(self, warm_start_values: dict):
+        """
+        Sample a predefined configuration (warm-start) or fill in random values if missing.
+
+        Parameters
+        ----------
+        warm_start_values: dict
+            Predefined configuration values for hyperparameters.
+
+        Returns
+        -------
+        A dictionary containing sampled values for each hyperparameter.
+        """
+        sampled_params = {}
+        for param, dimension in self.param_grid.items():
+            if param in warm_start_values:
+                sampled_params[param] = warm_start_values[param]
+            else:
+                sampled_params[param] = dimension.sample()  # Random sample if no warm-start value
+        return sampled_params
+
    @property
    def dimensions(self):
        """
