Skip to content

Commit

Permalink
Assert model types from keras.src
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Mar 18, 2024
1 parent 306911a commit ba14eb0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,9 @@ def workaround_cudnn_not_found_problem():
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)

def assert_model_type(model):
import keras
assert type(model) in [keras.src.models.sequential.Sequential, keras.src.models.functional.Functional]

def convert(in_path, out_path, no_tests=False):
"""Convert any (h5-)stored Keras model to the frugally-deep model format."""
Expand All @@ -895,7 +898,6 @@ def convert(in_path, out_path, no_tests=False):

print('loading {}'.format(in_path))
model = load_model(in_path, compile=False)
assert type(model) == tf.keras.models.Model
json_output = model_to_fdeep_json(model, no_tests)
print('writing {}'.format(out_path))

Expand Down

0 comments on commit ba14eb0

Please sign in to comment.