From 797aaf595ec6f5d2d7fdde8a4dc5eff8d9614fce Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Mon, 21 Oct 2024 08:59:02 -0500 Subject: [PATCH] Dont convert squeeze/unsqueeze to reshape when fusing mlir (#3533) --- src/targets/gpu/fuse_mlir.cpp | 5 +---- src/targets/gpu/mlir.cpp | 11 +++++++++-- test/gpu/fuse_mlir.cpp | 11 +++++------ test/gpu/mlir.cpp | 30 ++++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index b2d2804003b..1b759a89c7c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -182,6 +182,7 @@ const auto& reshaper_names() "broadcast", "contiguous", "reshape", + "lazy_reshape", "squeeze", "flatten", "unsqueeze" @@ -198,10 +199,6 @@ get_fusable_input_op_stream(instruction_ref lower_input) while(contains(reshaper_names(), upper_input->name())) { operation op = upper_input->get_operator(); - if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name())) - { - op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}}); - } op_stream.push_back(op); upper_input = upper_input->inputs().at(0); } diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 2442e2e8aed..32c903857a8 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -619,6 +619,11 @@ struct mlir_program return result; } + static bool is_reshape(const std::string& name) + { + return contains({"reshape", "lazy_reshape", "squeeze", "unsqueeze", "flatten"}, name); + } + static std::string get_name(instruction_ref ins) { if(ins->name() == "@return") @@ -627,6 +632,8 @@ struct mlir_program return "migraphx.literal"; if(ins->name() == "unpack_int4") return "migraphx.unpack"; + if(is_reshape(ins->name())) + return "migraphx.reshape"; return "migraphx." + ins->name(); } @@ -637,8 +644,8 @@ struct mlir_program // Reshape operator can have dim 0 or -1. // Avoid passing those on to MLIR: - if(op.name() == "reshape") - v["dims"] = ins->get_shape().lens(); + if(is_reshape(op.name())) + v = {{"dims", ins->get_shape().lens()}}; if(op.name() == "convolution" or op.name() == "quant_convolution") { diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 032cad2ff17..06034f77976 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -109,8 +109,7 @@ TEST_CASE(dot_reshapes_add) auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); auto dot_trans = pm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot); - auto dot_rsp = pm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 3}}}), - dot_trans); + auto dot_rsp = pm->add_instruction(migraphx::make_op("squeeze"), dot_trans); auto add = pm->add_instruction(migraphx::make_op("add"), dot_rsp, inputs[0]); return std::make_tuple(dot->get_operator(), add); }); @@ -499,8 +498,8 @@ TEST_CASE(dequantizelinear_dot) {y, scalelit, zplit, x}, {"x0", "x1", "x2", "x3"}, [=](auto* pm, const auto& inputs) { - auto unsqueeze1 = pm->add_instruction( - migraphx::make_op("reshape", {{"dims", {2, 2, 1, 2}}}), inputs[1]); + auto unsqueeze1 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[1]); auto broadcast1 = pm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); auto reshape1 = pm->add_instruction( @@ -509,8 +508,8 @@ TEST_CASE(dequantizelinear_dot) migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); - auto unsqueeze2 = pm->add_instruction( - migraphx::make_op("reshape", {{"dims", {2, 2, 1, 2}}}), inputs[2]); + auto unsqueeze2 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[2]); auto broadcast2 = pm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); auto reshape2 = pm->add_instruction( diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 586b102aafc..5453f60004a 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -361,6 +361,36 @@ module { EXPECT(verify_mlir(m)); } +TEST_CASE(unsqueeze_dot_add) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_unsqueeze_dot_add(%arg0: !migraphx.shaped<5x4xf32, 4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { + %0 = migraphx.reshape %arg0 {dims = [1, 5, 4]} : <5x4xf32, 4x1> -> <1x5x4xf32, 20x4x1> + %1 = migraphx.dot %0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %2 = migraphx.add %1, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> + return %2 : !migraphx.shaped<1x5x3xf32, 15x3x1> + } +} +)__migraphx__"; + migraphx::module m; + auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {5, 4}}); + auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}}); + auto arg2 = m.add_parameter("arg2", {migraphx::shape::float_type, {1, 5, 3}}); + auto unsqueeze = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), arg0); + auto dot = m.add_instruction(migraphx::make_op("dot"), unsqueeze, arg1); + auto add = m.add_instruction(migraphx::make_op("add"), dot, arg2); + m.add_return({add}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} + TEST_CASE(conv_int8_dequantize_quantize) { std::string mlir_output = R"__migraphx__(