diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py index 14cbd698..6209c99a 100755 --- a/keras_export/convert_model.py +++ b/keras_export/convert_model.py @@ -193,6 +193,8 @@ def show_separable_conv_2d_layer(layer): def show_batch_normalization_layer(layer): """Serialize batch normalization layer to dict""" + assert len(layer.axis) == 1 + assert layer.axis[0] == -1 or layer.axis[0] + 1 == len(layer.input_shape) moving_mean = K.get_value(layer.moving_mean) moving_variance = K.get_value(layer.moving_variance) result = {}