Skip to content

Commit

Permalink
Dont convert squeeze/unsqueeze to reshape when fusing mlir (#3533)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Oct 21, 2024
1 parent b73defb commit 797aaf5
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
5 changes: 1 addition & 4 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ const auto& reshaper_names()
"broadcast",
"contiguous",
"reshape",
"lazy_reshape",
"squeeze",
"flatten",
"unsqueeze"
Expand All @@ -198,10 +199,6 @@ get_fusable_input_op_stream(instruction_ref lower_input)
while(contains(reshaper_names(), upper_input->name()))
{
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
{
op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}});
}
op_stream.push_back(op);
upper_input = upper_input->inputs().at(0);
}
Expand Down
11 changes: 9 additions & 2 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,11 @@ struct mlir_program
return result;
}

static bool is_reshape(const std::string& name)
{
return contains({"reshape", "lazy_reshape", "squeeze", "unsqueeze", "flatten"}, name);
}

static std::string get_name(instruction_ref ins)
{
if(ins->name() == "@return")
Expand All @@ -627,6 +632,8 @@ struct mlir_program
return "migraphx.literal";
if(ins->name() == "unpack_int4")
return "migraphx.unpack";
if(is_reshape(ins->name()))
return "migraphx.reshape";
return "migraphx." + ins->name();
}

Expand All @@ -637,8 +644,8 @@ struct mlir_program

// Reshape operator can have dim 0 or -1.
// Avoid passing those on to MLIR:
if(op.name() == "reshape")
v["dims"] = ins->get_shape().lens();
if(is_reshape(op.name()))
v = {{"dims", ins->get_shape().lens()}};

if(op.name() == "convolution" or op.name() == "quant_convolution")
{
Expand Down
11 changes: 5 additions & 6 deletions test/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ TEST_CASE(dot_reshapes_add)
auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
auto dot_trans = pm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot);
auto dot_rsp = pm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 3}}}),
dot_trans);
auto dot_rsp = pm->add_instruction(migraphx::make_op("squeeze"), dot_trans);
auto add = pm->add_instruction(migraphx::make_op("add"), dot_rsp, inputs[0]);
return std::make_tuple(dot->get_operator(), add);
});
Expand Down Expand Up @@ -499,8 +498,8 @@ TEST_CASE(dequantizelinear_dot)
{y, scalelit, zplit, x},
{"x0", "x1", "x2", "x3"},
[=](auto* pm, const auto& inputs) {
auto unsqueeze1 = pm->add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 1, 2}}}), inputs[1]);
auto unsqueeze1 =
pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[1]);
auto broadcast1 = pm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1);
auto reshape1 = pm->add_instruction(
Expand All @@ -509,8 +508,8 @@ TEST_CASE(dequantizelinear_dot)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}),
reshape1);

auto unsqueeze2 = pm->add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 1, 2}}}), inputs[2]);
auto unsqueeze2 =
pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[2]);
auto broadcast2 = pm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2);
auto reshape2 = pm->add_instruction(
Expand Down
30 changes: 30 additions & 0 deletions test/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,36 @@ module {
EXPECT(verify_mlir(m));
}

TEST_CASE(unsqueeze_dot_add)
{
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_unsqueeze_dot_add(%arg0: !migraphx.shaped<5x4xf32, 4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} {
%0 = migraphx.reshape %arg0 {dims = [1, 5, 4]} : <5x4xf32, 4x1> -> <1x5x4xf32, 20x4x1>
%1 = migraphx.dot %0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%2 = migraphx.add %1, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %2 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::float_type, {1, 5, 3}});
auto unsqueeze = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), arg0);
auto dot = m.add_instruction(migraphx::make_op("dot"), unsqueeze, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), dot, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_int8_dequantize_quantize)
{
std::string mlir_output = R"__migraphx__(
Expand Down

0 comments on commit 797aaf5

Please sign in to comment.