GroupQueryAttention strange behaviour depending on seq_len/total_seq_len #22732
Unanswered · ManelSemidynamics asked this question in API Q&A · 1 comment, 1 reply
-
I think this behaviour is due to the CPU provider using a shared buffer for past_{key/value} and present_{key/value}. Is there any way to disable it, so that the proper tensor is printed as output?
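I am not aware of a documented ORT option that disables buffer sharing on the CPU execution provider, but as a diagnostic one can deep-copy the outputs immediately after `session.run`, so that later prints cannot be affected by any internal buffer reuse (the `snapshot` helper below is mine, not an ORT API):

```python
import numpy as np

def snapshot(outputs):
    """Deep-copy every array returned by session.run so that any later run,
    or any internal buffer reuse, cannot change what was already captured."""
    return [np.array(o, copy=True) for o in outputs]

# usage with the repro below:
#   outputs = snapshot(session.run(None, inputs))
```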
-
I am implementing my own version of GQA, and I found that when seq_len != 1 and total_seq_len != 2, the initial values of the "present_key" output are not the same as "past_key". Why is that? I expected present_key to be the concatenation (taking tensor formats into account) of "past_key" and "key" (with RoPE possibly applied to "key" before concatenation). seq_len=1 with total_seq_len=2 gives the output I expect.
Code to reproduce:
```python
import onnx
from onnx import helper, TensorProto
from collections import OrderedDict

kv_nh = 2          # kv_num_heads
nh = 4             # num_heads
max_seq_len = 32   # rows in the RoPE cos/sin caches

inputs = OrderedDict([
    ('query', helper.make_tensor_value_info('query', TensorProto.FLOAT16, ['batch_size', 'sequence_length', nh * 128])),
    ('key', helper.make_tensor_value_info('key', TensorProto.FLOAT16, ['batch_size', 'sequence_length', kv_nh * 128])),
    ('value', helper.make_tensor_value_info('value', TensorProto.FLOAT16, ['batch_size', 'sequence_length', kv_nh * 128])),
    ('past_key', helper.make_tensor_value_info('past_key', TensorProto.FLOAT16, ['batch_size', kv_nh, 'past_sequence_length', 128])),
    ('past_value', helper.make_tensor_value_info('past_value', TensorProto.FLOAT16, ['batch_size', kv_nh, 'past_sequence_length', 128])),
    ('seqlens_k', helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, ['batch_size', 1])),
    ('total_sequence_length', helper.make_tensor_value_info('total_sequence_length', TensorProto.INT32, [])),
    ('cos_cache', helper.make_tensor_value_info('cos_cache', TensorProto.FLOAT16, [max_seq_len, 64])),
    ('sin_cache', helper.make_tensor_value_info('sin_cache', TensorProto.FLOAT16, [max_seq_len, 64])),
])
outputs = OrderedDict([
    ('output', helper.make_tensor_value_info('output', TensorProto.FLOAT16, ['batch_size', 'sequence_length', nh * 128])),
    ('present_key', helper.make_tensor_value_info('present_key', TensorProto.FLOAT16, ['batch_size', kv_nh, 'total_sequence_length', 128])),
    ('present_value', helper.make_tensor_value_info('present_value', TensorProto.FLOAT16, ['batch_size', kv_nh, 'total_sequence_length', 128])),
])

node = helper.make_node(
    'GroupQueryAttention',
    inputs=list(inputs.keys()),
    outputs=list(outputs.keys()),
    name='GroupQueryAttention_Node',
    domain='com.microsoft',
    do_rotary=1,
    kv_num_heads=kv_nh,
    num_heads=nh,
    rotary_interleaved=1,
    scale=0.0888383461536163,
)
graph = helper.make_graph(
    [node],
    'GroupQueryAttentionGraph',
    list(inputs.values()),
    list(outputs.values()),
)
model = helper.make_model(
    graph,
    producer_name='onnx-helper',
    opset_imports=[
        helper.make_opsetid("", 19),
        helper.make_opsetid("com.microsoft", 1),
    ],
)
onnx.checker.check_model(model)
onnx.save(model, 'group_query_attention.onnx')

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession('group_query_attention.onnx')

batch = 1
seq_len = 2
tot_seq_len = 4
past = tot_seq_len - seq_len  # length of the KV cache being fed in

query = np.random.rand(batch, seq_len, nh * 128).astype(np.float16)
key = np.random.rand(batch, seq_len, kv_nh * 128).astype(np.float16)
value = np.random.rand(batch, seq_len, kv_nh * 128).astype(np.float16)
past_key = np.random.rand(batch, kv_nh, past, 128).astype(np.float16)
past_value = np.random.rand(batch, kv_nh, past, 128).astype(np.float16)
seqlens_k = np.array([[seq_len]], dtype=np.int32)
total_sequence_length = np.array(tot_seq_len, dtype=np.int32)
cos_cache = np.random.rand(max_seq_len, 64).astype(np.float16)
sin_cache = np.random.rand(max_seq_len, 64).astype(np.float16)

inputs = {
    'query': query,
    'key': key,
    'value': value,
    'past_key': past_key,
    'past_value': past_value,
    'seqlens_k': seqlens_k,
    'total_sequence_length': total_sequence_length,
    'cos_cache': cos_cache,
    'sin_cache': sin_cache,
}
outputs = session.run(None, inputs)
print("inputs")
print(inputs)
print("outputs")
print(outputs)
```
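For reference, the concatenation I expected can be sketched in NumPy as follows. This is my reading of the contract, not a confirmed ORT reference implementation: interleaved RoPE is applied to the new key tokens at absolute positions `past .. past + seq_len - 1`, the result is reshaped to the BNSH cache layout, and it is appended after `past_key`. The helper names are mine.

```python
import numpy as np

def rope_interleaved(x, cos_cache, sin_cache, pos_offset):
    """Apply interleaved RoPE to x of shape (batch, heads, seq, head_dim).

    Pairs (x[..., 2i], x[..., 2i+1]) are rotated by the angle for rotary
    dimension i at absolute position pos_offset + t.
    """
    b, h, s, d = x.shape
    out = x.copy()
    for t in range(s):
        cos = cos_cache[pos_offset + t]  # shape (d // 2,)
        sin = sin_cache[pos_offset + t]
        x0 = x[:, :, t, 0::2]
        x1 = x[:, :, t, 1::2]
        out[:, :, t, 0::2] = x0 * cos - x1 * sin
        out[:, :, t, 1::2] = x0 * sin + x1 * cos
    return out

def expected_present_key(key, past_key, cos_cache, sin_cache,
                         kv_nh=2, head_dim=128):
    """past_key followed by the rotated new keys, in BNSH layout.

    key comes in as (batch, seq, kv_nh * head_dim) and is reshaped to
    (batch, kv_nh, seq, head_dim) to match the cache layout.
    """
    batch, seq_len, _ = key.shape
    past = past_key.shape[2]
    k = key.reshape(batch, seq_len, kv_nh, head_dim).transpose(0, 2, 1, 3)
    k = rope_interleaved(k, cos_cache, sin_cache, pos_offset=past)
    return np.concatenate([past_key, k], axis=2)
```

With the arrays from the repro above, `expected_present_key(key, past_key, cos_cache, sin_cache)` is what I would compare against the `present_key` output. The `pos_offset=past` choice is an assumption on my part; it matches the seq_len=1, total_seq_len=2 case that behaves as expected.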