Skip to content

Commit

Permalink
simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Apr 10, 2024
1 parent 3ddcf8b commit 3714bae
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def replace_none_with(value, shape):


def get_first_outbound_op(layer):
"""determine outbound operation"""
"""Determine primary outbound operation"""
return layer._outbound_nodes[0].operation


Expand All @@ -138,7 +138,7 @@ def embedding_layer_names(model):
if isinstance(layer, Embedding):
result.add(layer.name)
layer_type = type(layer).__name__
if layer_type in ['model', 'sequential', 'functional', 'Model', 'Sequential', 'Functional']:
if layer_type in ['Model', 'Sequential', 'Functional']:
result.union(embedding_layer_names(layer))
return result

Expand Down Expand Up @@ -728,10 +728,9 @@ def get_all_weights(model, prefix):
assert is_ascii(name)
if name in result:
raise ValueError('duplicate layer name ' + name)
if layer_type in ['model', 'sequential', 'functional', 'Model', 'Sequential', 'Functional']:
if layer_type in ['Model', 'Sequential', 'Functional']:
result = merge_two_disjunct_dicts(result, get_all_weights(layer, name + '_'))
elif layer_type in ['TimeDistributed'] and type(layer.layer).__name__ in ['model', 'sequential', 'functional',
'Model', 'Sequential', 'Functional']:
elif layer_type in ['TimeDistributed'] and type(layer.layer).__name__ in ['Model', 'Sequential', 'Functional']:
inner_layer = layer.layer
result = merge_two_disjunct_dicts(result, get_layer_weights(layer, name))
result = merge_two_disjunct_dicts(result, get_all_weights(inner_layer, name + "_"))
Expand Down Expand Up @@ -762,7 +761,7 @@ def convert_sequential_to_model(model):
model._inbound_nodes = inbound_nodes
if type(model).__name__ == 'TimeDistributed':
model.layer = convert_sequential_to_model(model.layer)
if type(model).__name__ in ['model', 'functional', 'Model', 'Functional']:
if type(model).__name__ in ['Model', 'Functional']:
for i in range(len(model.layers)):
new_layer = convert_sequential_to_model(model.layers[i])
if new_layer == model.layers[i]:
Expand Down

0 comments on commit 3714bae

Please sign in to comment.