diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py index b6740d9b..996fa0b3 100755 --- a/keras_export/convert_model.py +++ b/keras_export/convert_model.py @@ -6,6 +6,7 @@ import datetime import hashlib import json +import os import sys import numpy as np @@ -596,10 +597,6 @@ def get_all_weights(model, prefix): layers = model.layers assert K.image_data_format() == 'channels_last' for layer in layers: - for node in layer.inbound_nodes: - if "training" in node.call_kwargs: - assert node.call_kwargs["training"] is not True, \ - "training=true is not supported, see https://github.com/Dobiasd/frugally-deep/issues/284" layer_type = type(layer).__name__ name = prefix + layer.name assert is_ascii(name) @@ -750,6 +747,32 @@ def singleton_list_to_value(value_or_values): return value_or_values +def remove_training_flags(model): + """ + In case a layer has set training=True, this flag is removed. + """ + layers = model.layers + for layer in layers: + for node in layer.inbound_nodes: + if "training" in node.call_kwargs: + print(f"Removing training=True from inbound node to layer named {layer.name}.") + del node.call_kwargs["training"] + layer_type = type(layer).__name__ + if layer_type in ['Model', 'Sequential', 'Functional']: + remove_training_flags(layer) + + +def remove_training_flags_and_reload(model, temp_model_filename): + """ + The reload is needed to propagate the change in the model. + """ + remove_training_flags(model) + model.save(temp_model_filename, include_optimizer=False) + model = load_model(temp_model_filename) + os.remove(temp_model_filename) + return model + + def model_to_fdeep_json(model, no_tests=False): """Convert any Keras model to the frugally-deep model format.""" @@ -788,6 +811,7 @@ def convert(in_path, out_path, no_tests=False): print('loading {}'.format(in_path)) model = load_model(in_path) + model = remove_training_flags_and_reload(model, out_path + ".fdeep_temp.h5") json_output = model_to_fdeep_json(model, no_tests) print('writing {}'.format(out_path)) write_text_file(out_path, json.dumps(