Skip to content

Commit dcb75e2

Browse files
committed
Make kerasga.predict faster by not cloning the model
predict() cloned the whole model on every call and then used model.predict(). When predict() runs once per solution inside a fitness function, this is slow. It now sets the solution as the model weights, restores the original weights at the end, and calls the model directly when no batching is asked for. The output is the same and it runs much faster. model.predict() is still used when batch_size or steps is given.
1 parent 4f64b4e commit dcb75e2

1 file changed

Lines changed: 20 additions & 8 deletions

File tree

pygad/kerasga/kerasga.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def predict(model,
7979
steps=None):
8080
"""
8181
Load the given solution as the model's weights and run a forward
82-
pass on ``data``. The model is cloned first so the caller's
83-
instance is left untouched.
82+
pass on ``data``. The model's original weights are restored
83+
afterwards, so the model passed by the caller is not changed.
8484
8585
Parameters
8686
----------
@@ -105,12 +105,24 @@ def predict(model,
105105
# Fetch the parameters of the best solution.
106106
solution_weights = model_weights_as_matrix(model=model,
107107
weights_vector=solution)
108-
_model = tensorflow.keras.models.clone_model(model)
109-
_model.set_weights(solution_weights)
110-
predictions = _model.predict(x=data,
111-
batch_size=batch_size,
112-
verbose=verbose,
113-
steps=steps)
108+
109+
# Set the solution as the model weights, then put the original weights
110+
# back at the end. The model is not cloned because cloning it on every
111+
# call is slow when predict() is used inside a fitness function.
112+
original_weights = model.get_weights()
113+
model.set_weights(solution_weights)
114+
try:
115+
if batch_size is None and steps is None:
116+
# When no batching is asked for, call the model directly. This
117+
# is faster than model.predict() when called once per solution.
118+
predictions = numpy.array(model(data, training=False))
119+
else:
120+
predictions = model.predict(x=data,
121+
batch_size=batch_size,
122+
verbose=verbose,
123+
steps=steps)
124+
finally:
125+
model.set_weights(original_weights)
114126

115127
return predictions
116128

0 commit comments

Comments
 (0)