Skip to content

Commit

Permalink
simplify further
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Apr 10, 2024
1 parent d2a850c commit 3ddcf8b
Showing 1 changed file with 7 additions and 21 deletions.
28 changes: 7 additions & 21 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def show_tensor(tens):


def get_model_input_layers(model):
"""Works for different Keras version."""
"""Gets the input layers from model.layers in the correct input order."""
if len(model.inputs) == 1:
from keras.src.layers.core.input_layer import InputLayer
input_layers = []
Expand Down Expand Up @@ -121,12 +121,8 @@ def replace_none_with(value, shape):


def get_first_outbound_op(layer):
first_node = layer._outbound_nodes[0]
if hasattr(first_node, "outbound_layer"):
return first_node.outbound_layer
if hasattr(first_node, "operation"):
return first_node.operation
raise ValueError("Can't determine outbound operation")
"""determine outbound operation"""
return layer._outbound_nodes[0].operation


def are_embedding_layer_positions_ok_for_testing(model):
Expand Down Expand Up @@ -745,35 +741,25 @@ def get_all_weights(model, prefix):


def get_model_name(model):
"""Return .name or ._name or 'dummy_model_name'"""
"""Return .name or ._name"""
if hasattr(model, 'name'):
return model.name
if hasattr(model, '_name'):
return model._name
return 'dummy_model_name'
return model._name


def convert_sequential_to_model(model):
"""Convert a sequential model to the underlying functional format"""
if type(model).__name__ in ['sequential', 'Sequential']:
name = get_model_name(model)
if hasattr(model, '_inbound_nodes'):
inbound_nodes = model._inbound_nodes
elif hasattr(model, 'inbound_nodes'):
inbound_nodes = model.inbound_nodes
else:
raise ValueError('can not get (_)inbound_nodes from model')
inbound_nodes = model._inbound_nodes
input_layer = Input(batch_shape=get_layer_input_shape(model.layers[0]))
prev_layer = input_layer
for layer in model.layers:
layer._inbound_nodes = []
prev_layer = layer(prev_layer)
funcmodel = Model([input_layer], [prev_layer], name=name)
model = funcmodel
if hasattr(model, '_inbound_nodes'):
model._inbound_nodes = inbound_nodes
elif hasattr(model, 'inbound_nodes'):
model.inbound_nodes = inbound_nodes
model._inbound_nodes = inbound_nodes
if type(model).__name__ == 'TimeDistributed':
model.layer = convert_sequential_to_model(model.layer)
if type(model).__name__ in ['model', 'functional', 'Model', 'Functional']:
Expand Down

0 comments on commit 3ddcf8b

Please sign in to comment.