Skip to content

Commit

Permalink
re-add layer position checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Apr 16, 2024
1 parent 974a6bc commit 029a721
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,35 @@ def get_first_outbound_op(layer):
return layer._outbound_nodes[0].operation


def are_embedding_and_category_encoding_layer_positions_ok_for_testing(model):
"""
Test data can only be generated if all Embedding layers
and CategoryEncoding layers are positioned directly behind the input nodes.
"""

def embedding_layer_names(model):
layers = model.layers
result = set()
for layer in layers:
if isinstance(layer, Embedding):
result.add(layer.name)
layer_type = type(layer).__name__
if layer_type in ['Model', 'Sequential', 'Functional']:
result.union(embedding_layer_names(layer))
return result

def embedding_layer_names_at_input_nodes(model):
result = set()
for input_layer in get_model_input_layers(model):
if input_layer._outbound_nodes and (
isinstance(get_first_outbound_op(input_layer), Embedding) or
isinstance(get_first_outbound_op(input_layer), CategoryEncoding)):
result.add(get_first_outbound_op(input_layer).name)
return set(result)

return embedding_layer_names(model) == embedding_layer_names_at_input_nodes(model)


def gen_test_data(model):
"""Generate data for model verification test."""

Expand Down Expand Up @@ -118,6 +147,9 @@ def generate_input_data(input_layer):
return random_fn(
size=replace_none_with(32, set_shape_idx_0_to_1_if_none(singleton_list_to_value(shape)))).astype(np.float32)

assert are_embedding_and_category_encoding_layer_positions_ok_for_testing(
model), "Test data can only be generated if embedding layers are positioned directly after input nodes."

data_in = list(map(generate_input_data, get_model_input_layers(model)))

warm_up_runs = 3
Expand Down

0 comments on commit 029a721

Please sign in to comment.