diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py index 018a2124..d8eb14c3 100644 --- a/keras_export/generate_test_models.py +++ b/keras_export/generate_test_models.py @@ -463,6 +463,12 @@ def get_test_model_exhaustive(): outputs.append(MultiHeadAttention( num_heads=1, key_dim=1, value_dim=None, use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) + outputs.append(MultiHeadAttention( + num_heads=2, key_dim=3, value_dim=5, + use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) + outputs.append(MultiHeadAttention( + num_heads=2, key_dim=3, value_dim=5, + use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) shared_conv = Conv2D(1, (1, 1), padding='valid', name='shared_conv', activation='relu')