Fix GroupQueryAttention Max Sequence Length Propagation (#3539)
turneram authored Oct 21, 2024
1 parent c1d2d5a commit 8d14e88
Showing 8 changed files with 40 additions and 75 deletions.
8 changes: 3 additions & 5 deletions src/include/migraphx/op/group_query_attention.hpp
@@ -42,7 +42,6 @@ struct group_query_attention
     std::size_t num_heads = 1;
     bool rotary_interleaved = false;
     float scale = 1.0;
-    std::size_t present_kv_seqlen = 4096;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct group_query_attention
           f(self.local_window_size, "local_window_size"),
           f(self.num_heads, "num_heads"),
           f(self.rotary_interleaved, "rotary_interleaved"),
-          f(self.scale, "scale"),
-          f(self.present_kv_seqlen, "present_kv_seqlen"));
+          f(self.scale, "scale"));
     }
 
     std::string name() const { return "group_query_attention"; }
@@ -438,7 +436,7 @@ struct group_query_attention
     argument present_k_out{kv_shape};
     argument present_v_out{kv_shape};
     argument attention_probs{shape{
-        output_shape_0.type(), {batch_size, num_heads, sequence_length, present_kv_seqlen}}};
+        output_shape_0.type(), {batch_size, num_heads, sequence_length, past_sequence_length}}};
 
     args[0] = args[0].reshape(
         shape{output_shape_0.type(),
@@ -507,7 +505,7 @@ struct group_query_attention
     gqa_params.head_stride = head_stride;
     gqa_params.batch_stride = batch_stride;
     gqa_params.position_ids_use_batch = position_ids_use_batch;
-    gqa_params.seqlen_present_kv_cache = present_kv_seqlen;
+    gqa_params.seqlen_present_kv_cache = past_sequence_length;
     gqa_params.past_present_share_buffer = false;
 
     if(do_rotary)
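The substantive change in this header: present_kv_seqlen is no longer a reflected attribute with a hard-coded default of 4096. Fields listed in reflect() are what get serialized and compared for a MIGraphX operator, so the old default could be frozen into a program even when the actual KV cache had a different length; the buffer and parameter setup later in the file now uses past_sequence_length, presumably derived from the runtime input shapes. A minimal standalone sketch of the reflect idiom (toy types only, not MIGraphX's real pack machinery):

#include <cstddef>
#include <iostream>
#include <string>
#include <tuple>

// f(field, name) is applied to every reflected field; this visitor just prints.
struct print_field
{
    template <class T>
    int operator()(T& field, const std::string& name) const
    {
        std::cout << name << " = " << field << '\n';
        return 0;
    }
};

struct toy_gqa
{
    std::size_t num_heads = 1;
    float scale           = 1.0f;
    // present_kv_seqlen is gone: a field is part of the op's serialized value
    // only if reflect() lists it, so removing it here removes it everywhere.

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return std::make_tuple(f(self.num_heads, "num_heads"), f(self.scale, "scale"));
    }
};

int main()
{
    toy_gqa op;
    toy_gqa::reflect(op, print_field{});
}

Dropping a field from reflect() removes it from the serialized form in one place; the remaining hunks in this commit chase down the code that still read it.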
18 changes: 8 additions & 10 deletions src/onnx/parse_group_query_attention.cpp
@@ -37,7 +37,7 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
     std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                        const onnx_parser& parser,
                                        const onnx_parser::node_info& info,
-                                       std::vector<instruction_ref> args) const
+                                       const std::vector<instruction_ref>& args) const
     {
         bool do_rotary = false;
         std::size_t kv_num_heads = 0;
@@ -77,15 +77,13 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
             MIGRAPHX_THROW("GroupQueryAttention: Wrong number of inputs provided");
         }
 
-        auto present_kv_seqlen = args.at(args.size() - 6)->get_shape().lens()[2];
-        auto gqa = info.add_instruction(make_op("group_query_attention",
-                                                {{"do_rotary", do_rotary},
-                                                 {"kv_num_heads", kv_num_heads},
-                                                 {"local_window_size", local_window_size},
-                                                 {"num_heads", num_heads},
-                                                 {"rotary_interleaved", rotary_interleaved},
-                                                 {"scale", scale},
-                                                 {"present_kv_seqlen", present_kv_seqlen}}),
+        auto gqa = info.add_instruction(make_op("group_query_attention",
+                                                {{"do_rotary", do_rotary},
+                                                 {"kv_num_heads", kv_num_heads},
+                                                 {"local_window_size", local_window_size},
+                                                 {"num_heads", num_heads},
+                                                 {"rotary_interleaved", rotary_interleaved},
+                                                 {"scale", scale}}),
                                         args);
         auto gqa_output = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), gqa);
         auto gqa_present_key = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), gqa);
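Two changes here: args is now taken by const reference (the parser never mutates it, so the by-value copy was waste), and the parser no longer precomputes the cache length and bakes it into the op. The deleted lookup counted backwards from the end of the argument list; with all nine GroupQueryAttention inputs present (the ONNX Runtime contrib-op input order is assumed below), args.size() - 6 lands on past_key, whose dim 2 held the length. A small sketch of what that indexing did:

#include <cassert>
#include <string>
#include <vector>

int main()
{
    // Assumed ONNX Runtime input order for GroupQueryAttention:
    std::vector<std::string> args{"query", "key", "value",
                                  "past_key", "past_value", "seqlens_k",
                                  "total_sequence_length", "cos_cache", "sin_cache"};
    // The removed lookup: count back six from the end, landing on past_key,
    // whose shape's dim 2 was baked into the op as an attribute.
    assert(args.at(args.size() - 6) == "past_key");
    // args.size() - 6 == 3 only while every trailing input is present, which
    // is why deriving the length from shapes downstream is sturdier.
}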
@@ -90,7 +90,7 @@ static inline gqa_parameters init_params(const std::vector<shape>& inputs, const
     auto local_window_size = v.at("local_window_size").to<std::uint32_t>();
     auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
     auto scale = v.at("scale").to<float>();
-    auto present_kv_seqlen = v.at("present_kv_seqlen").to<std::size_t>();
+    auto present_kv_seqlen = inputs[1].lens().size() == 4 ? inputs[1].lens()[2] : 0;
 
     const auto& q_shape = inputs[0];
     auto q_lens = q_shape.lens();
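This is the core of the fix on the codegen side: rather than trusting a serialized attribute, init_params derives present_kv_seqlen from inputs[1], which after the prefuse_ops change below is always the present-K cache. The cache is assumed to be a 4-D tensor shaped {batch, kv_num_heads, kv_seqlen, head_dim}, so dim 2 carries the length, and any non-4-D input falls back to 0. The same derivation as a standalone, runnable sketch:

#include <cstddef>
#include <vector>

// Hedged sketch of the line above: the KV cache is assumed to be laid out as
// {batch, kv_num_heads, kv_seqlen, head_dim}, so dim 2 is the present
// sequence length; anything that is not rank 4 falls back to 0.
std::size_t present_kv_seqlen(const std::vector<std::size_t>& kv_cache_lens)
{
    return kv_cache_lens.size() == 4 ? kv_cache_lens[2] : 0;
}

int main()
{
    return present_kv_seqlen({1, 32, 2048, 128}) == 2048 ? 0 : 1;
}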
@@ -114,21 +114,9 @@ update_cache(const Present present, SeqLensK seqlens_k, Cache cache, Params params)
     }
 }
 
-// template <class Output, class Query, class PastKey, class PastValue, class SeqLensK, class
-// Params>
-// __device__ void
-// concat_past_present(Output, const Query query, PastKey past_key, PastValue past_value, SeqLensK
-// seqlens_k, Params params)
-template <class Query,
-          class PastKey,
-          class PastValue,
-          class SeqLensK,
-          /* class Output, */ class Params>
-__device__ void concat_past_present(const Query query,
-                                    PastKey past_key,
-                                    PastValue past_value,
-                                    SeqLensK seqlens_k,
-                                    /* const Output, */ Params params)
+template <class Query, class PastKey, class PastValue, class SeqLensK, class Params>
+__device__ void concat_past_present(
+    const Query query, PastKey past_key, PastValue past_value, SeqLensK seqlens_k, Params params)
 {
     auto ind = make_index();
     auto elements =
@@ -123,8 +123,9 @@ __device__ void calculate_softmax(AttnProbs attention_probs, // output buffer wi
     }
 }
 
-template <class Output, class Input, class Probs, class SeqLensK, class Params>
-__device__ void gqa_softmax(Output output, Input, Probs, SeqLensK seqlens_k, Params params)
+template <class Output, class Input, class PresentKey, class Probs, class SeqLensK, class Params>
+__device__ void
+gqa_softmax(Output output, Input, PresentKey, Probs, SeqLensK seqlens_k, Params params)
 {
     const index_int elements = params.batch_size * params.num_heads * params.sequence_length;
     auto ind = make_index();
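Both kernel edits are signature hygiene for the same rewiring: concat_past_present drops its dead, commented-out Output parameter, and gqa_softmax gains a PresentKey parameter that is deliberately left unnamed. The kernel body never reads the present-K tensor; accepting it just keeps the device function's arity in step with the input list now assembled in prefuse_ops, where that tensor's shape is what init_params consumes. A toy sketch of the unnamed-parameter idiom:

// Toy version of the idiom above: the function must accept an argument so its
// arity matches the call site, but never reads it, so the parameter gets a
// type and no name (which also avoids unused-parameter warnings).
template <class Output, class Input, class PresentKey>
void toy_softmax(Output& out, Input in, PresentKey /*present_k, unused*/)
{
    out = in; // stand-in for the real per-row softmax work
}

int main()
{
    double out = 0.0;
    toy_softmax(out, 1.0, 42);
    return out == 1.0 ? 0 : 1;
}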
2 changes: 1 addition & 1 deletion src/targets/gpu/lowering.cpp
@@ -554,7 +554,7 @@ struct miopen_apply
         auto inputs = ins->inputs();
 
         auto new_inputs = ins->inputs();
-        new_inputs.push_back(inputs.at(1));
+        new_inputs.push_back(inputs.at(2));
         return mod->replace_instruction(
             ins,
             make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
60 changes: 20 additions & 40 deletions src/targets/gpu/prefuse_ops.cpp
@@ -280,7 +280,7 @@ struct gpu_gqa_softmax : op::group_query_attention
 {
     std::string name() const { return "gpu::gqa_softmax"; }
 
-    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
+    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
 };
 MIGRAPHX_REGISTER_OP(gpu_gqa_softmax);
 
@@ -308,7 +308,6 @@ struct find_group_query_attention
         auto local_window_size = v.at("local_window_size").to<int>();
         auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
         auto scale = v.at("scale").to<float>();
-        auto present_kv_seqlen = v.at("present_kv_seqlen").to<std::size_t>();
 
         auto q_shape = inputs[0]->get_shape();
         auto q_lens = q_shape.lens();
@@ -338,66 +337,47 @@
                                                           local_window_size,
                                                           num_heads,
                                                           rotary_interleaved,
-                                                          scale,
-                                                          present_kv_seqlen},
+                                                          scale},
                 rotary_inputs);
         }
 
         auto pres_k = inputs.at(3);
         auto pres_v = inputs.at(4);
         std::vector<instruction_ref> concat_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5)};
 
-        auto concat =
-            mpm.get_module().insert_instruction(ins,
-                                                gpu_concat_past_present{do_rotary,
-                                                                        kv_num_heads,
-                                                                        local_window_size,
-                                                                        num_heads,
-                                                                        rotary_interleaved,
-                                                                        scale,
-                                                                        present_kv_seqlen},
-                                                concat_inputs);
+        auto concat = mpm.get_module().insert_instruction(
+            ins,
+            gpu_concat_past_present{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            concat_inputs);
         auto id =
             mpm.get_module().insert_instruction(ins, make_op("identity"), concat, pres_k, pres_v);
 
         std::vector<instruction_ref> attn_probs_inputs{id, pres_k, pres_v, inputs.at(5)};
         auto attn_probs = mpm.get_module().insert_instruction(
             ins,
-            gpu_compute_attention_probabilities{do_rotary,
-                                                kv_num_heads,
-                                                local_window_size,
-                                                num_heads,
-                                                rotary_interleaved,
-                                                scale,
-                                                present_kv_seqlen},
+            gpu_compute_attention_probabilities{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
             attn_probs_inputs);
 
-        std::vector<instruction_ref> softmax_inputs{rotary_qkv, attn_probs, inputs.at(5)};
-        auto softmax = mpm.get_module().insert_instruction(ins,
-                                                           gpu_gqa_softmax{do_rotary,
-                                                                           kv_num_heads,
-                                                                           local_window_size,
-                                                                           num_heads,
-                                                                           rotary_interleaved,
-                                                                           scale,
-                                                                           present_kv_seqlen},
-                                                           softmax_inputs);
+        std::vector<instruction_ref> softmax_inputs{rotary_qkv, pres_k, attn_probs, inputs.at(5)};
+        auto softmax = mpm.get_module().insert_instruction(
+            ins,
+            gpu_gqa_softmax{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            softmax_inputs);
         std::vector<instruction_ref> new_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5), softmax};
 
         auto get_tuple_elm_0 = std::next(ins);
         auto get_tuple_elm_1 = std::next(get_tuple_elm_0);
         auto get_tuple_elm_2 = std::next(get_tuple_elm_1);
         mpm.get_module().replace_instruction(get_tuple_elm_2, pres_v);
         mpm.get_module().replace_instruction(get_tuple_elm_1, pres_k);
-        mpm.get_module().replace_instruction(get_tuple_elm_0,
-                                             gpu_compute_attention_scores{do_rotary,
-                                                                          kv_num_heads,
-                                                                          local_window_size,
-                                                                          num_heads,
-                                                                          rotary_interleaved,
-                                                                          scale,
-                                                                          present_kv_seqlen},
-                                             new_inputs);
+        mpm.get_module().replace_instruction(
+            get_tuple_elm_0,
+            gpu_compute_attention_scores{
+                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
+            new_inputs);
     }
 };
 
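The pass splits GQA into rotary, concat_past_present, attention-probabilities, softmax, and attention-scores stages. Besides dropping present_kv_seqlen from every constructor, the one behavioral edit is threading pres_k into softmax_inputs. My reading of the resulting invariant (with inputs.at(5) assumed to be seqlens_k): every sub-op now carries the present-K cache at input index 1, which is exactly where init_params looks for the KV length, and the softmax's probs buffer consequently shifts from index 1 to index 2, matching the compute_shape change above and the lowering.cpp change. A sketch of that wiring:

#include <cassert>
#include <string>
#include <vector>

int main()
{
    using stage_inputs = std::vector<std::string>;
    // Input lists as assembled by the pass after this change (names from the
    // diff; inputs.at(5) assumed to be seqlens_k):
    stage_inputs concat {"rotary_qkv", "pres_k", "pres_v", "seqlens_k"};
    stage_inputs probs  {"id", "pres_k", "pres_v", "seqlens_k"};
    stage_inputs softmax{"rotary_qkv", "pres_k", "attn_probs", "seqlens_k"};
    stage_inputs scores {"rotary_qkv", "pres_k", "pres_v", "seqlens_k", "softmax"};
    // The invariant replacing the deleted attribute: every sub-op carries the
    // present-K cache at input 1, where init_params now reads the KV length.
    for(const auto& stage : {concat, probs, softmax, scores})
        assert(stage.at(1) == "pres_k");
    // And the softmax's probs buffer moved from index 1 to index 2, matching
    // gpu_gqa_softmax::compute_shape and the lowering.cpp push_back.
    assert(softmax.at(2) == "attn_probs");
}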
2 changes: 1 addition & 1 deletion test/verify/test_group_query_attention_gen.cpp
@@ -34,7 +34,7 @@ struct test_group_query_attention_gen : verify_program<test_group_query_attention_gen>
         migraphx::program p;
         auto* mm = p.get_main_module();
         std::vector<size_t> query_lens{1, 1, 12288};
-        std::vector<size_t> kv_lens{1, 32, 4096, 128};
+        std::vector<size_t> kv_lens{1, 32, 2048, 128};
         std::vector<size_t> slk_lens{1, 1};
         std::vector<size_t> tsl_lens{1, 1};
         std::vector<size_t> cs_cache_lens{4096, 64};
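The verify test now builds the KV cache with length 2048 rather than 4096. The choice looks deliberate: since the deleted attribute defaulted to 4096, a 4096-long cache would have passed even with propagation broken, while 2048 makes a stale default produce mismatched buffer shapes. Spelled out as a check:

#include <cstddef>
#include <vector>

int main()
{
    // Updated test cache: {batch, kv_heads, kv_seqlen, head_dim}.
    std::vector<std::size_t> kv_lens{1, 32, 2048, 128};
    // 2048 deliberately differs from the old attribute default of 4096, so
    // the test only passes if the length truly comes from this shape.
    return kv_lens[2] != 4096 ? 0 : 1;
}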
