diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py index 0d6c35d4..809168f6 100644 --- a/keras_export/generate_test_models.py +++ b/keras_export/generate_test_models.py @@ -436,15 +436,25 @@ def get_test_model_exhaustive(): outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50]])) outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50], inputs[51]])) + outputs.append(MultiHeadAttention( + num_heads=1, key_dim=1, value_dim=None, + use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50])) outputs.append(MultiHeadAttention( num_heads=1, key_dim=1, value_dim=None, use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50])) outputs.append(MultiHeadAttention( - num_heads=3, key_dim=2, value_dim=None, - use_bias=False, output_shape=None, attention_axes=(2, 3))(inputs[2], inputs[3])) + num_heads=3, key_dim=1, value_dim=None, + use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50])) + # todo: re-enable + #outputs.append(MultiHeadAttention( + # num_heads=1, key_dim=2, value_dim=None, + # use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50])) + #outputs.append(MultiHeadAttention( + # num_heads=1, key_dim=1, value_dim=2, + # use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50])) outputs.append(MultiHeadAttention( - num_heads=3, key_dim=2, value_dim=1, - use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) + num_heads=1, key_dim=1, value_dim=None, + use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51])) shared_conv = Conv2D(1, (1, 1), padding='valid', name='shared_conv', activation='relu')