Skip to content

Commit

Permalink
Remove training=True from inbound nodes during conversion, fixes #284
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Jul 6, 2021
1 parent 6f9e4b1 commit 357d663
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import hashlib
import json
import os
import sys

import numpy as np
Expand Down Expand Up @@ -596,10 +597,6 @@ def get_all_weights(model, prefix):
layers = model.layers
assert K.image_data_format() == 'channels_last'
for layer in layers:
for node in layer.inbound_nodes:
if "training" in node.call_kwargs:
assert node.call_kwargs["training"] is not True, \
"training=true is not supported, see https://github.com/Dobiasd/frugally-deep/issues/284"
layer_type = type(layer).__name__
name = prefix + layer.name
assert is_ascii(name)
Expand Down Expand Up @@ -750,6 +747,32 @@ def singleton_list_to_value(value_or_values):
return value_or_values


def remove_training_flags(model):
"""
In case a layer has set training=True, this flag is removed.
"""
layers = model.layers
for layer in layers:
for node in layer.inbound_nodes:
if "training" in node.call_kwargs:
print(f"Removing training=True from inbound node to layer named {layer.name}.")
del node.call_kwargs["training"]
layer_type = type(layer).__name__
if layer_type in ['Model', 'Sequential', 'Functional']:
remove_training_flags(layer)


def remove_training_flags_and_reload(model, temp_model_filename):
"""
The reload is needed to propagate the change in the model.
"""
remove_training_flags(model)
model.save(temp_model_filename, include_optimizer=False)
model = load_model(temp_model_filename)
os.remove(temp_model_filename)
return model


def model_to_fdeep_json(model, no_tests=False):
"""Convert any Keras model to the frugally-deep model format."""

Expand Down Expand Up @@ -788,6 +811,7 @@ def convert(in_path, out_path, no_tests=False):

print('loading {}'.format(in_path))
model = load_model(in_path)
model = remove_training_flags_and_reload(model, out_path + ".fdeep_temp.h5")
json_output = model_to_fdeep_json(model, no_tests)
print('writing {}'.format(out_path))
write_text_file(out_path, json.dumps(
Expand Down

0 comments on commit 357d663

Please sign in to comment.