diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py
index 91bf7073..66fa9259 100755
--- a/keras_export/convert_model.py
+++ b/keras_export/convert_model.py
@@ -566,6 +566,7 @@ def show_multi_head_attention_layer(layer):
     assert len(layer.input_shape) == 3
     assert layer.input_shape[0] is None
     assert layer._output_shape is None
+    assert layer._attention_axes == (1,), "MultiHeadAttention supported only with attention_axes=None"
     return {
         'weight_shapes': list(map(lambda w: list(w.shape), layer.weights)),
         'weights': list(map(lambda w: encode_floats(w.numpy()), layer.weights)),
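
Note on the new assert: it checks the private tf.keras attribute `_attention_axes`, which (as far as I can tell) is resolved to `(1,)` during build when the layer was constructed with the default `attention_axes=None` and receives a rank-3 query of shape (batch, seq, features); hence the error message. A minimal sketch of that assumption, not part of this diff:

    import numpy as np
    import tensorflow as tf

    # Layer constructed with the default attention_axes=None.
    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

    # Calling the layer builds it and resolves attention_axes=None
    # into a concrete tuple on the private attribute.
    x = np.zeros((1, 5, 8), dtype=np.float32)  # (batch, seq, features)
    _ = mha(x, x)

    # Assumption being asserted by the converter: attention over the
    # sequence axis only, i.e. the default behavior for rank-3 inputs.
    print(mha._attention_axes)  # expected: (1,)

If a user passes an explicit `attention_axes`, `_attention_axes` would differ from `(1,)` and the converter now fails loudly instead of exporting wrong weights.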