Skip to content

Commit

Permalink
adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Dec 29, 2023
1 parent 53f2d9b commit 4184515
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,25 @@ def get_test_model_exhaustive():
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50]]))
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50], inputs[51]]))

outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=2, value_dim=None,
use_bias=False, output_shape=None, attention_axes=(2, 3))(inputs[2], inputs[3]))
num_heads=3, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
# todo: re-enable
#outputs.append(MultiHeadAttention(
# num_heads=1, key_dim=2, value_dim=None,
# use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
#outputs.append(MultiHeadAttention(
# num_heads=1, key_dim=1, value_dim=2,
# use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=2, value_dim=1,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))

shared_conv = Conv2D(1, (1, 1),
padding='valid', name='shared_conv', activation='relu')
Expand Down

0 comments on commit 4184515

Please sign in to comment.