diff --git a/src/include/migraphx/op/group_query_attention.hpp b/src/include/migraphx/op/group_query_attention.hpp
index 81cd464bbf..60246f4210 100644
--- a/src/include/migraphx/op/group_query_attention.hpp
+++ b/src/include/migraphx/op/group_query_attention.hpp
@@ -42,7 +42,6 @@ struct group_query_attention
     std::size_t num_heads         = 1;
     bool rotary_interleaved       = false;
     float scale                   = 1.0;
-    std::size_t present_kv_seqlen = 4096;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct group_query_attention
                  f(self.local_window_size, "local_window_size"),
                  f(self.num_heads, "num_heads"),
                  f(self.rotary_interleaved, "rotary_interleaved"),
-                 f(self.scale, "scale"),
-                 f(self.present_kv_seqlen, "present_kv_seqlen"));
+                 f(self.scale, "scale"));
     }
 
     std::string name() const { return "group_query_attention"; }
@@ -438,7 +436,7 @@ struct group_query_attention
         argument present_k_out{kv_shape};
         argument present_v_out{kv_shape};
         argument attention_probs{shape{
-            output_shape_0.type(), {batch_size, num_heads, sequence_length, present_kv_seqlen}}};
+            output_shape_0.type(), {batch_size, num_heads, sequence_length, past_sequence_length}}};
 
         args[0] = args[0].reshape(
             shape{output_shape_0.type(),
@@ -507,7 +505,7 @@ struct group_query_attention
         gqa_params.head_stride             = head_stride;
         gqa_params.batch_stride            = batch_stride;
         gqa_params.position_ids_use_batch  = position_ids_use_batch;
-        gqa_params.seqlen_present_kv_cache = present_kv_seqlen;
+        gqa_params.seqlen_present_kv_cache = past_sequence_length;
         gqa_params.past_present_share_buffer = false;
 
         if(do_rotary)
diff --git a/src/onnx/parse_group_query_attention.cpp b/src/onnx/parse_group_query_attention.cpp
index af78fd040a..7c36c252e0 100644
--- a/src/onnx/parse_group_query_attention.cpp
+++ b/src/onnx/parse_group_query_attention.cpp
@@ -37,7 +37,7 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
     std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                        const onnx_parser& parser,
                                        const onnx_parser::node_info& info,
-                                       std::vector<instruction_ref> args) const
+                                       const std::vector<instruction_ref>& args) const
     {
         bool do_rotary           = false;
         std::size_t kv_num_heads = 0;
@@ -77,15 +77,13 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
             MIGRAPHX_THROW("GroupQueryAttention: Wrong number of inputs provided");
         }
 
-        auto present_kv_seqlen = args.at(args.size() - 6)->get_shape().lens()[2];
-        auto gqa               = info.add_instruction(make_op("group_query_attention",
-                                                              {{"do_rotary", do_rotary},
-                                                               {"kv_num_heads", kv_num_heads},
-                                                               {"local_window_size", local_window_size},
-                                                               {"num_heads", num_heads},
-                                                               {"rotary_interleaved", rotary_interleaved},
-                                                               {"scale", scale},
-                                                               {"present_kv_seqlen", present_kv_seqlen}}),
+        auto gqa = info.add_instruction(make_op("group_query_attention",
+                                                {{"do_rotary", do_rotary},
+                                                 {"kv_num_heads", kv_num_heads},
+                                                 {"local_window_size", local_window_size},
+                                                 {"num_heads", num_heads},
+                                                 {"rotary_interleaved", rotary_interleaved},
+                                                 {"scale", scale}}),
                                         args);
         auto gqa_output = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), gqa);
         auto gqa_present_key = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), gqa);
diff --git a/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp b/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
index 58cc1623fc..b7690c4bff 100644
--- a/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/group_query_attention.hpp
@@ -90,7 +90,7 @@ static inline gqa_parameters init_params(const std::vector<shape>& inputs, const value& v)
     auto local_window_size  = v.at("local_window_size").to<int>();
     auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
     auto scale              = v.at("scale").to<float>();
-    auto present_kv_seqlen  = v.at("present_kv_seqlen").to<std::size_t>();
+    auto present_kv_seqlen  = inputs[1].lens().size() == 4 ? inputs[1].lens()[2] : 0;
 
     const auto& q_shape = inputs[0];
     auto q_lens         = q_shape.lens();
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp
index 9cf3526743..dcfebbbe0d 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp
@@ -114,21 +114,9 @@ update_cache(const Present present, SeqLensK seqlens_k, Cache cache, Params params)
     }
 }
 
-// template <class Output,
-//           class Query,
-//           class PastKey,
-//           class PastValue,
-//           class SeqLensK,
-//           class Params>
-// __device__ void
-// concat_past_present(Output, const Query query, PastKey past_key, PastValue past_value, SeqLensK
-// seqlens_k, Params params)
-template <class Query, class PastKey, class PastValue, class SeqLensK, class Params>
-__device__ void concat_past_present(const Query query,
-                                    PastKey past_key,
-                                    PastValue past_value,
-                                    SeqLensK seqlens_k,
-                                    /* const Output, */ Params params)
+template <class Query, class PastKey, class PastValue, class SeqLensK, class Params>
+__device__ void concat_past_present(
+    const Query query, PastKey past_key, PastValue past_value, SeqLensK seqlens_k, Params params)
 {
     auto ind      = make_index();
     auto elements =
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/gqa_softmax.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/gqa_softmax.hpp
index 0ad5767e6f..27e2154b6f 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/gqa_softmax.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/gqa_softmax.hpp
@@ -123,8 +123,9 @@ __device__ void calculate_softmax(AttnProbs attention_probs, // output buffer wi
     }
 }
 
-template <class Output, class Input, class Probs, class SeqLensK, class Params>
-__device__ void gqa_softmax(Output output, Input, Probs, SeqLensK seqlens_k, Params params)
+template <class Output, class Input, class PresentKey, class Probs, class SeqLensK, class Params>
+__device__ void
+gqa_softmax(Output output, Input, PresentKey, Probs, SeqLensK seqlens_k, Params params)
 {
     const index_int elements = params.batch_size * params.num_heads * params.sequence_length;
     auto ind                 = make_index();
diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp
index 0c52d8c1ad..94df8db185 100644
--- a/src/targets/gpu/lowering.cpp
+++ b/src/targets/gpu/lowering.cpp
@@ -554,7 +554,7 @@ struct miopen_apply
 
             auto inputs     = ins->inputs();
             auto new_inputs = ins->inputs();
-            new_inputs.push_back(inputs.at(1));
+            new_inputs.push_back(inputs.at(2));
             return mod->replace_instruction(
                 ins,
                 make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp
index d655b93f45..f8a8f83759 100644
--- a/src/targets/gpu/prefuse_ops.cpp
+++ b/src/targets/gpu/prefuse_ops.cpp
@@ -280,7 +280,7 @@ struct gpu_gqa_softmax : op::group_query_attention
 {
     std::string name() const { return "gpu::gqa_softmax"; }
 
-    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
+    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
 };
 
 MIGRAPHX_REGISTER_OP(gpu_gqa_softmax);
@@ -308,7 +308,6 @@ struct find_group_query_attention
         auto local_window_size  = v.at("local_window_size").to<int>();
         auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
         auto scale              = v.at("scale").to<float>();
-        auto present_kv_seqlen  = v.at("present_kv_seqlen").to<std::size_t>();
 
         auto q_shape = inputs[0]->get_shape();
         auto q_lens  = q_shape.lens();
@@ -338,8 +337,7 @@ struct find_group_query_attention
                                                 local_window_size,
                                                 num_heads,
                                                 rotary_interleaved,
-                                                scale,
-                                                present_kv_seqlen},
+                                                scale},
                                  rotary_inputs);
         }
 
@@ -347,41 +345,27 @@ struct find_group_query_attention
         auto pres_v = inputs.at(4);
         std::vector<instruction_ref> concat_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5)};
 
-        auto concat =
-            mpm.get_module().insert_instruction(ins,
-                                                gpu_concat_past_present{do_rotary,
-                                                                        kv_num_heads,
-                                                                        local_window_size,
-                                                                        num_heads,
-                                                                        rotary_interleaved,
-                                                                        scale,
-                                                                        present_kv_seqlen},
-                                                concat_inputs);
+        auto concat = mpm.get_module().insert_instruction(
+            ins,
+            gpu_concat_past_present{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            concat_inputs);
         auto id =
             mpm.get_module().insert_instruction(ins, make_op("identity"), concat, pres_k, pres_v);
         std::vector<instruction_ref> attn_probs_inputs{id, pres_k, pres_v, inputs.at(5)};
         auto attn_probs = mpm.get_module().insert_instruction(
             ins,
-            gpu_compute_attention_probabilities{do_rotary,
-                                                kv_num_heads,
-                                                local_window_size,
-                                                num_heads,
-                                                rotary_interleaved,
-                                                scale,
-                                                present_kv_seqlen},
+            gpu_compute_attention_probabilities{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
             attn_probs_inputs);
 
-        std::vector<instruction_ref> softmax_inputs{rotary_qkv, attn_probs, inputs.at(5)};
-        auto softmax = mpm.get_module().insert_instruction(ins,
-                                                           gpu_gqa_softmax{do_rotary,
-                                                                           kv_num_heads,
-                                                                           local_window_size,
-                                                                           num_heads,
-                                                                           rotary_interleaved,
-                                                                           scale,
-                                                                           present_kv_seqlen},
-                                                           softmax_inputs);
+        std::vector<instruction_ref> softmax_inputs{rotary_qkv, pres_k, attn_probs, inputs.at(5)};
+        auto softmax = mpm.get_module().insert_instruction(
+            ins,
+            gpu_gqa_softmax{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            softmax_inputs);
 
         std::vector<instruction_ref> new_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5), softmax};
         auto get_tuple_elm_0 = std::next(ins);
@@ -389,15 +373,11 @@ struct find_group_query_attention
         auto get_tuple_elm_2 = std::next(get_tuple_elm_1);
         mpm.get_module().replace_instruction(get_tuple_elm_2, pres_v);
         mpm.get_module().replace_instruction(get_tuple_elm_1, pres_k);
-        mpm.get_module().replace_instruction(get_tuple_elm_0,
-                                             gpu_compute_attention_scores{do_rotary,
-                                                                          kv_num_heads,
-                                                                          local_window_size,
-                                                                          num_heads,
-                                                                          rotary_interleaved,
-                                                                          scale,
-                                                                          present_kv_seqlen},
-                                             new_inputs);
+        mpm.get_module().replace_instruction(
+            get_tuple_elm_0,
+            gpu_compute_attention_scores{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            new_inputs);
     }
 };
diff --git a/test/verify/test_group_query_attention_gen.cpp b/test/verify/test_group_query_attention_gen.cpp
index d5624db653..86badc5acc 100644
--- a/test/verify/test_group_query_attention_gen.cpp
+++ b/test/verify/test_group_query_attention_gen.cpp
@@ -34,7 +34,7 @@ struct test_group_query_attention_gen : verify_program<test_group_query_attention_gen>
         std::vector<std::size_t> query_lens{1, 1, 12288};
-        std::vector<std::size_t> kv_lens{1, 32, 4096, 128};
+        std::vector<std::size_t> kv_lens{1, 32, 2048, 128};
        std::vector<std::size_t> slk_lens{1, 1};
        std::vector<std::size_t> tsl_lens{1, 1};
        std::vector<std::size_t> cs_cache_lens{4096, 64};
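
Note on the approach: with the present_kv_seqlen attribute removed, init_params() now reads the present KV cache length off the shape of the key cache input (inputs[1]), whose layout is {batch, kv_num_heads, seq_len, head_size} (e.g. {1, 32, 2048, 128} in the updated verify test); dimension 2 is the sequence length, and anything that is not 4-D falls back to 0. A minimal standalone sketch of that derivation, with a hypothetical dims alias standing in for what migraphx::shape::lens() returns:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for migraphx::shape::lens(): just the dimension sizes.
    using dims = std::vector<std::size_t>;

    // Mirrors the logic added to init_params(): take the present KV sequence
    // length from dimension 2 of the key cache shape
    // {batch, kv_num_heads, seq_len, head_size}; non-4-D shapes yield 0.
    std::size_t present_kv_seqlen(const dims& past_key_lens)
    {
        return past_key_lens.size() == 4 ? past_key_lens[2] : 0;
    }

    int main()
    {
        dims past_key{1, 32, 2048, 128}; // the kv_lens used in the verify test
        std::cout << present_kv_seqlen(past_key) << '\n'; // prints 2048
    }

Deriving the length from the cache input rather than storing it as an operator attribute keeps the two from drifting apart as the graph is rewritten, at the cost of requiring the cache input's shape to be known wherever the parameters are initialized.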