From edeeb877b70ff372bc1ef0ec68ed0508e4569bbd Mon Sep 17 00:00:00 2001 From: Alejandro Cid Delgado Date: Thu, 20 Feb 2025 10:59:17 -0800 Subject: [PATCH] check for sequence length --- .../contrib_ops/cpu/bert/multihead_attention_helper.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 0cfe90963c334..83fc572ebe3b5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -359,6 +359,11 @@ Status CheckInputs(const T* query, ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len, batch_size, num_heads, head_size, past_present_share_buffer, past_sequence_length, max_sequence_length)); + + if (past_sequence_length > 1 && sequence_length > 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' is expected to have sequence_length == 1 when past_sequence_length > 1"); + } } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent");