diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py index 10591364..c9dcb5d9 100755 --- a/keras_export/convert_model.py +++ b/keras_export/convert_model.py @@ -652,14 +652,13 @@ class CopiedLayer: for attr in attributes: try: - if attr not in ['input_shape', '__class__']: + if attr not in ['batch_shape', '__class__']: setattr(copied_layer, attr, getattr(layer.layer, attr)) - elif attr == 'input_shape': - setattr(copied_layer, 'input_shape', input_shape_new) except Exception: continue - setattr(copied_layer, "output_shape", getattr(layer, "output_shape")) + setattr(copied_layer, 'batch_shape', input_shape_new) + setattr(copied_layer, "output_shape", layer.output.shape) return layer_function(copied_layer) @@ -711,7 +710,7 @@ def get_layer_weights(layer, name): result[name]['td_input_len'] = encode_floats( np.array([len(get_layer_input_shape(layer)) - 1], dtype=np.float32)) - result[name]['td_output_len'] = encode_floats(np.array([len(layer.output_shape) - 1], dtype=np.float32)) + result[name]['td_output_len'] = encode_floats(np.array([len(layer.output.shape) - 1], dtype=np.float32)) return result