From 3714bae524e01802441be2addae22653635cb5d5 Mon Sep 17 00:00:00 2001 From: Dobiasd Date: Wed, 10 Apr 2024 09:43:21 +0200 Subject: [PATCH] simpler --- keras_export/convert_model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py index a3a95584..a3980e94 100755 --- a/keras_export/convert_model.py +++ b/keras_export/convert_model.py @@ -121,7 +121,7 @@ def replace_none_with(value, shape): def get_first_outbound_op(layer): - """determine outbound operation""" + """Determine primary outbound operation""" return layer._outbound_nodes[0].operation @@ -138,7 +138,7 @@ def embedding_layer_names(model): if isinstance(layer, Embedding): result.add(layer.name) layer_type = type(layer).__name__ - if layer_type in ['model', 'sequential', 'functional', 'Model', 'Sequential', 'Functional']: + if layer_type in ['Model', 'Sequential', 'Functional']: result.union(embedding_layer_names(layer)) return result @@ -728,10 +728,9 @@ def get_all_weights(model, prefix): assert is_ascii(name) if name in result: raise ValueError('duplicate layer name ' + name) - if layer_type in ['model', 'sequential', 'functional', 'Model', 'Sequential', 'Functional']: + if layer_type in ['Model', 'Sequential', 'Functional']: result = merge_two_disjunct_dicts(result, get_all_weights(layer, name + '_')) - elif layer_type in ['TimeDistributed'] and type(layer.layer).__name__ in ['model', 'sequential', 'functional', - 'Model', 'Sequential', 'Functional']: + elif layer_type in ['TimeDistributed'] and type(layer.layer).__name__ in ['Model', 'Sequential', 'Functional']: inner_layer = layer.layer result = merge_two_disjunct_dicts(result, get_layer_weights(layer, name)) result = merge_two_disjunct_dicts(result, get_all_weights(inner_layer, name + "_")) @@ -762,7 +761,7 @@ def convert_sequential_to_model(model): model._inbound_nodes = inbound_nodes if type(model).__name__ == 'TimeDistributed': model.layer = convert_sequential_to_model(model.layer) - if type(model).__name__ in ['model', 'functional', 'Model', 'Functional']: + if type(model).__name__ in ['Model', 'Functional']: for i in range(len(model.layers)): new_layer = convert_sequential_to_model(model.layers[i]) if new_layer == model.layers[i]: