Skip to content

Commit

Permalink
Add parse tests for multi head attention
Browse files Browse the repository at this point in the history
  • Loading branch information
marko-fabo-htec committed Dec 18, 2024
1 parent 3394d29 commit 3fbb3fc
Show file tree
Hide file tree
Showing 17 changed files with 701 additions and 0 deletions.
153 changes: 153 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7864,6 +7864,159 @@ def mha_scale_test():
return ([node], [query, key, value], [out])


@onnx_test()
def mha_invalid_attribute_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 2, 4])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 2, 4])
value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 2, 4])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 2, 4])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k', 'v'],
outputs=['out'],
domain='com.microsoft')

return ([node], [query, key, value], [out])


@onnx_test()
def mha_invalid_input_test():
node = helper.make_node('MultiHeadAttention',
inputs=[],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [], [])


@onnx_test()
def mha_invalid_query_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query], [])


@onnx_test()
def mha_invalid_qkv_test():
qkv = helper.make_tensor_value_info("qkv", TensorProto.FLOAT,
[1, 1, 1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['qkv'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [qkv], [])


@onnx_test()
def mha_invalid_key_missing_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query], [])


@onnx_test()
def mha_invalid_key_ndim_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, key], [])


@onnx_test()
def mha_invalid_kv_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
kv = helper.make_tensor_value_info("kv", TensorProto.FLOAT, [1, 1, 1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'kv'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, kv], [])


@onnx_test()
def mha_invalid_key_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1, 2])
value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k', 'v'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, key, value], [])


@onnx_test()
def mha_invalid_value_missing_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, key], [])


@onnx_test()
def mha_invalid_value_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1, 1])
value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [2, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k', 'v'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, key, value], [])


@onnx_test()
def mha_invalid_value_ndim_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 1])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1, 1])
value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 1, 1, 1])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k', 'v'],
outputs=[],
num_heads=1,
domain='com.microsoft')

return ([node], [query, key, value], [])


@onnx_test()
def multinomial_test():
sample_size = 13
Expand Down
25 changes: 25 additions & 0 deletions test/onnx/mha_invalid_attribute_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
 mha_invalid_attribute_test:�
1
q
k
vout"MultiHeadAttention:com.microsoftmha_invalid_attribute_testZ
q



Z
k



Z
v



b
out



B
Expand Down
3 changes: 3 additions & 0 deletions test/onnx/mha_invalid_input_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
 mha_invalid_input_test:O
5"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_input_testB
Expand Down
9 changes: 9 additions & 0 deletions test/onnx/mha_invalid_key_missing_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
 mha_invalid_key_missing_test:q
8
q"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_key_missing_testZ
q



B
Expand Down
14 changes: 14 additions & 0 deletions test/onnx/mha_invalid_key_ndim_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 mha_invalid_key_ndim_test:�
;
q
k"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_key_ndim_testZ
q



Z
k


B
Expand Down
21 changes: 21 additions & 0 deletions test/onnx/mha_invalid_key_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
 mha_invalid_key_test:�
>
q
k
v"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_key_testZ
q



Z
k



Z
v



B
Expand Down
17 changes: 17 additions & 0 deletions test/onnx/mha_invalid_kv_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
 mha_invalid_kv_test:�
<
q
kv"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_kv_testZ
q



Z
kv





B
Expand Down
11 changes: 11 additions & 0 deletions test/onnx/mha_invalid_qkv_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
 mha_invalid_qkv_test:u
:
qkv"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_qkv_testZ!
qkv





B
Expand Down
8 changes: 8 additions & 0 deletions test/onnx/mha_invalid_query_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
 mha_invalid_query_test:g
8
q"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_query_testZ
q


B
Expand Down
15 changes: 15 additions & 0 deletions test/onnx/mha_invalid_value_missing_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
 mha_invalid_value_missing_test:�
;
q
k"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_value_missing_testZ
q



Z
k



B
Expand Down
22 changes: 22 additions & 0 deletions test/onnx/mha_invalid_value_ndim_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
 mha_invalid_value_ndim_test:�
>
q
k
v"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_value_ndim_testZ
q



Z
k



Z
v




B
Expand Down
21 changes: 21 additions & 0 deletions test/onnx/mha_invalid_value_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
 mha_invalid_value_test:�
>
q
k
v"MultiHeadAttention*
num_heads�:com.microsoftmha_invalid_value_testZ
q



Z
k



Z
v



B
Expand Down
Loading

0 comments on commit 3fbb3fc

Please sign in to comment.