Fix GroupQueryAttention Max Sequence Length Propagation #3539

Merged
merged 4 commits on
Oct 21, 2024
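In brief: the operator previously carried a fixed present_kv_seqlen attribute (defaulting to 4096) that the parser and fusion passes propagated by hand; after this change the present KV sequence length is taken from the shape of the KV-cache tensor itself. A minimal standalone sketch of that idea, using a hypothetical helper name rather than MIGraphX's real shape API:

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not MIGraphX API) mirroring the idea of the fix:
// instead of storing a fixed present_kv_seqlen attribute on the operator,
// read the sequence-length dimension directly from the 4-D KV-cache shape
// {batch, kv_heads, seqlen, head_size}.
std::size_t present_kv_seqlen_from_cache(const std::vector<std::size_t>& kv_cache_lens)
{
    if(kv_cache_lens.size() != 4)
        throw std::runtime_error("expected a 4-D KV-cache shape");
    return kv_cache_lens[2]; // sequence-length dimension of the cache
}

int main()
{
    // Matches the updated verify test below: kv_lens {1, 32, 2048, 128} -> 2048.
    std::cout << present_kv_seqlen_from_cache({1, 32, 2048, 128}) << "\n";
    return 0;
}
```

The updated verify test at the end of the diff uses a 2048-long cache precisely so this derived value differs from the old 4096 default.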
8 changes: 3 additions & 5 deletions src/include/migraphx/op/group_query_attention.hpp
@@ -42,7 +42,6 @@ struct group_query_attention
std::size_t num_heads = 1;
bool rotary_interleaved = false;
float scale = 1.0;
std::size_t present_kv_seqlen = 4096;

template <class Self, class F>
static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct group_query_attention
f(self.local_window_size, "local_window_size"),
f(self.num_heads, "num_heads"),
f(self.rotary_interleaved, "rotary_interleaved"),
f(self.scale, "scale"),
f(self.present_kv_seqlen, "present_kv_seqlen"));
f(self.scale, "scale"));
}

std::string name() const { return "group_query_attention"; }
@@ -438,7 +436,7 @@ struct group_query_attention
argument present_k_out{kv_shape};
argument present_v_out{kv_shape};
argument attention_probs{shape{
output_shape_0.type(), {batch_size, num_heads, sequence_length, present_kv_seqlen}}};
output_shape_0.type(), {batch_size, num_heads, sequence_length, past_sequence_length}}};

args[0] = args[0].reshape(
shape{output_shape_0.type(),
@@ -507,7 +505,7 @@ struct group_query_attention
gqa_params.head_stride = head_stride;
gqa_params.batch_stride = batch_stride;
gqa_params.position_ids_use_batch = position_ids_use_batch;
gqa_params.seqlen_present_kv_cache = present_kv_seqlen;
gqa_params.seqlen_present_kv_cache = past_sequence_length;
gqa_params.past_present_share_buffer = false;

if(do_rotary)
18 changes: 8 additions & 10 deletions src/onnx/parse_group_query_attention.cpp
@@ -37,7 +37,7 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
bool do_rotary = false;
std::size_t kv_num_heads = 0;
@@ -77,15 +77,13 @@ struct parse_group_query_attention : op_parser<parse_group_query_attention>
MIGRAPHX_THROW("GroupQueryAttention: Wrong number of inputs provided");
}

auto present_kv_seqlen = args.at(args.size() - 6)->get_shape().lens()[2];
auto gqa = info.add_instruction(make_op("group_query_attention",
{{"do_rotary", do_rotary},
{"kv_num_heads", kv_num_heads},
{"local_window_size", local_window_size},
{"num_heads", num_heads},
{"rotary_interleaved", rotary_interleaved},
{"scale", scale},
{"present_kv_seqlen", present_kv_seqlen}}),
auto gqa = info.add_instruction(make_op("group_query_attention",
{{"do_rotary", do_rotary},
{"kv_num_heads", kv_num_heads},
{"local_window_size", local_window_size},
{"num_heads", num_heads},
{"rotary_interleaved", rotary_interleaved},
{"scale", scale}}),
args);
auto gqa_output = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), gqa);
auto gqa_present_key = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), gqa);
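For context on the removed parser line: assuming the full nine-input ONNX GroupQueryAttention signature (query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache, as in the onnxruntime contrib op), args.size() - 6 points at past_key, so the old code read the cache length from past_key's third dimension. A small sketch of that index arithmetic with illustrative names:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

int main()
{
    // Assumed ONNX GroupQueryAttention input order (onnxruntime contrib-op layout).
    const std::vector<std::string> gqa_inputs = {"query", "key", "value",
                                                 "past_key", "past_value", "seqlens_k",
                                                 "total_sequence_length", "cos_cache", "sin_cache"};
    const std::size_t idx = gqa_inputs.size() - 6; // the index the removed line used
    assert(gqa_inputs.at(idx) == "past_key");      // its lens()[2] held the cache length
    return 0;
}
```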
@@ -90,7 +90,7 @@ static inline gqa_parameters init_params(const std::vector<shape>& inputs, const
auto local_window_size = v.at("local_window_size").to<std::uint32_t>();
auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
auto scale = v.at("scale").to<float>();
auto present_kv_seqlen = v.at("present_kv_seqlen").to<std::size_t>();
auto present_kv_seqlen = inputs[1].lens().size() == 4 ? inputs[1].lens()[2] : 0;
Collaborator:

Why does it get set to 0 for non-4d tensors?

Contributor Author:

gqa_rotary_embedding doesn't take either of the kv-cache tensors as inputs and doesn't use the present_kv_seqlen parameter. Since its inputs[1] is 2-D and the value of present_kv_seqlen doesn't matter there, I just set it to zero and avoid trying to read lens()[2].
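A contrived sketch of the guard being discussed (illustrative types, not the MIGraphX shape API), showing why a 2-D inputs[1] falls back to 0:

```cpp
#include <cstddef>
#include <vector>

// Only a 4-D KV-cache input carries a sequence-length dimension at index 2;
// the 2-D input seen by the rotary-embedding kernel never uses the value, so
// 0 is a harmless placeholder that avoids an out-of-range read of lens()[2].
std::size_t derive_present_kv_seqlen(const std::vector<std::size_t>& input1_lens)
{
    return input1_lens.size() == 4 ? input1_lens[2] : 0;
}

int main()
{
    const std::size_t from_cache  = derive_present_kv_seqlen({1, 32, 4096, 128}); // 4096
    const std::size_t from_rotary = derive_present_kv_seqlen({1, 64});            // 0
    return (from_cache == 4096 && from_rotary == 0) ? 0 : 1;
}
```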


const auto& q_shape = inputs[0];
auto q_lens = q_shape.lens();
@@ -114,21 +114,9 @@ update_cache(const Present present, SeqLensK seqlens_k, Cache cache, Params para
}
}

// template <class Output, class Query, class PastKey, class PastValue, class SeqLensK, class
// Params>
// __device__ void
// concat_past_present(Output, const Query query, PastKey past_key, PastValue past_value, SeqLensK
// seqlens_k, Params params)
template <class Query,
class PastKey,
class PastValue,
class SeqLensK,
/* class Output, */ class Params>
__device__ void concat_past_present(const Query query,
PastKey past_key,
PastValue past_value,
SeqLensK seqlens_k,
/* const Output, */ Params params)
template <class Query, class PastKey, class PastValue, class SeqLensK, class Params>
__device__ void concat_past_present(
const Query query, PastKey past_key, PastValue past_value, SeqLensK seqlens_k, Params params)
{
auto ind = make_index();
auto elements =
@@ -123,8 +123,9 @@ __device__ void calculate_softmax(AttnProbs attention_probs, // output buffer wi
}
}

template <class Output, class Input, class Probs, class SeqLensK, class Params>
__device__ void gqa_softmax(Output output, Input, Probs, SeqLensK seqlens_k, Params params)
template <class Output, class Input, class PresentKey, class Probs, class SeqLensK, class Params>
__device__ void
gqa_softmax(Output output, Input, PresentKey, Probs, SeqLensK seqlens_k, Params params)
{
const index_int elements = params.batch_size * params.num_heads * params.sequence_length;
auto ind = make_index();
2 changes: 1 addition & 1 deletion src/targets/gpu/lowering.cpp
@@ -554,7 +554,7 @@ struct miopen_apply
auto inputs = ins->inputs();

auto new_inputs = ins->inputs();
new_inputs.push_back(inputs.at(1));
new_inputs.push_back(inputs.at(2));
return mod->replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
60 changes: 20 additions & 40 deletions src/targets/gpu/prefuse_ops.cpp
@@ -280,7 +280,7 @@ struct gpu_gqa_softmax : op::group_query_attention
{
std::string name() const { return "gpu::gqa_softmax"; }

shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
};
MIGRAPHX_REGISTER_OP(gpu_gqa_softmax);

@@ -308,7 +308,6 @@ struct find_group_query_attention
auto local_window_size = v.at("local_window_size").to<int>();
auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
auto scale = v.at("scale").to<float>();
auto present_kv_seqlen = v.at("present_kv_seqlen").to<std::size_t>();

auto q_shape = inputs[0]->get_shape();
auto q_lens = q_shape.lens();
@@ -338,66 +337,47 @@ struct find_group_query_attention
local_window_size,
num_heads,
rotary_interleaved,
scale,
present_kv_seqlen},
scale},
rotary_inputs);
}

auto pres_k = inputs.at(3);
auto pres_v = inputs.at(4);
std::vector<instruction_ref> concat_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5)};

auto concat =
mpm.get_module().insert_instruction(ins,
gpu_concat_past_present{do_rotary,
kv_num_heads,
local_window_size,
num_heads,
rotary_interleaved,
scale,
present_kv_seqlen},
concat_inputs);
auto concat = mpm.get_module().insert_instruction(
ins,
gpu_concat_past_present{
do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
concat_inputs);
auto id =
mpm.get_module().insert_instruction(ins, make_op("identity"), concat, pres_k, pres_v);

std::vector<instruction_ref> attn_probs_inputs{id, pres_k, pres_v, inputs.at(5)};
auto attn_probs = mpm.get_module().insert_instruction(
ins,
gpu_compute_attention_probabilities{do_rotary,
kv_num_heads,
local_window_size,
num_heads,
rotary_interleaved,
scale,
present_kv_seqlen},
gpu_compute_attention_probabilities{
do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
attn_probs_inputs);

std::vector<instruction_ref> softmax_inputs{rotary_qkv, attn_probs, inputs.at(5)};
auto softmax = mpm.get_module().insert_instruction(ins,
gpu_gqa_softmax{do_rotary,
kv_num_heads,
local_window_size,
num_heads,
rotary_interleaved,
scale,
present_kv_seqlen},
softmax_inputs);
std::vector<instruction_ref> softmax_inputs{rotary_qkv, pres_k, attn_probs, inputs.at(5)};
auto softmax = mpm.get_module().insert_instruction(
ins,
gpu_gqa_softmax{
do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
softmax_inputs);
std::vector<instruction_ref> new_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5), softmax};

auto get_tuple_elm_0 = std::next(ins);
auto get_tuple_elm_1 = std::next(get_tuple_elm_0);
auto get_tuple_elm_2 = std::next(get_tuple_elm_1);
mpm.get_module().replace_instruction(get_tuple_elm_2, pres_v);
mpm.get_module().replace_instruction(get_tuple_elm_1, pres_k);
mpm.get_module().replace_instruction(get_tuple_elm_0,
gpu_compute_attention_scores{do_rotary,
kv_num_heads,
local_window_size,
num_heads,
rotary_interleaved,
scale,
present_kv_seqlen},
new_inputs);
mpm.get_module().replace_instruction(
get_tuple_elm_0,
gpu_compute_attention_scores{
do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
new_inputs);
}
};

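One detail worth spelling out: because pres_k is now inserted into the softmax input list, the attention-probabilities tensor moves from index 1 to index 2, which is why gpu_gqa_softmax::compute_shape returns inputs.at(2) and the matching index in lowering.cpp appears to track the same shift. A toy illustration of the index change (plain strings standing in for the real instruction_ref inputs):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main()
{
    // Old softmax inputs: {rotary_qkv, attn_probs, inputs.at(5)}          -> probs at index 1.
    // New softmax inputs: {rotary_qkv, pres_k, attn_probs, inputs.at(5)}  -> probs at index 2.
    const std::vector<std::string> old_inputs = {"rotary_qkv", "attn_probs", "inputs.at(5)"};
    const std::vector<std::string> new_inputs = {"rotary_qkv", "pres_k", "attn_probs", "inputs.at(5)"};
    assert(old_inputs.at(1) == "attn_probs");
    assert(new_inputs.at(2) == "attn_probs");
    return 0;
}
```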
2 changes: 1 addition & 1 deletion test/verify/test_group_query_attention_gen.cpp
@@ -34,7 +34,7 @@ struct test_group_query_attention_gen : verify_program<test_group_query_attentio
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> query_lens{1, 1, 12288};
std::vector<size_t> kv_lens{1, 32, 4096, 128};
std::vector<size_t> kv_lens{1, 32, 2048, 128};
std::vector<size_t> slk_lens{1, 1};
std::vector<size_t> tsl_lens{1, 1};
std::vector<size_t> cs_cache_lens{4096, 64};
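For reference, one plausible reading of the updated test shapes (an interpretation, not stated by the test itself): the KV cache now has a sequence length of 2048, distinct from the 4096-row cos/sin cache, so the verify test exercises propagating the actual cache length rather than the old hard-coded 4096 default. A quick consistency check of those literals:

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    // kv_lens {1, 32, 2048, 128}: batch, KV heads, cache sequence length, head size.
    const std::size_t kv_heads = 32, cache_seqlen = 2048, head_size = 128;
    // query_lens {1, 1, 12288}: consistent with a packed QKV hidden size of
    // (num_heads + 2 * kv_heads) * head_size with 32 query heads (an assumption).
    const std::size_t num_heads = 32;
    assert((num_heads + 2 * kv_heads) * head_size == 12288);
    // The cache length (2048) now differs from the 4096-row cos/sin cache,
    // so the derived present KV sequence length no longer coincides with
    // the old default attribute value.
    assert(cache_seqlen != 4096);
    return 0;
}
```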