diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index dc564acfbcb..5be8fd4a25e 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -174,12 +174,67 @@ static auto any_input(Ms... ms)
     return match::any_of[match::inputs()](match::any(ms...).bind("input"));
 }
 
+// Returns true when the dimensions of `b`'s output shape that have stride 0
+// are exactly the axes in `reduce_axes` (same values, in the same order).
+// Zero-stride dimensions are gathered in ascending order, so `reduce_axes`
+// presumably has to be sorted ascending to compare equal -- TODO confirm.
+static bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
+{
+    std::vector<size_t> broadcast_axes;
+    auto bstrides = b->get_shape().strides();
+
+    for(size_t i = 0; i < bstrides.size(); ++i)
+    {
+        if(bstrides.at(i) == 0)
+            broadcast_axes.push_back(i);
+    }
+
+    return broadcast_axes == reduce_axes;
+}
+
+// Wraps matcher `m`: when the match context holds a "broadcast" binding, the
+// match is rejected unless is_valid_broadcast() accepts that broadcast for
+// the axes of the reduce instruction.
+template <class M>
+static auto match_broadcast_axes(M m)
+{
+    return match::make_basic_fun_matcher(
+        [=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+            optional<instruction_ref> result = m.match(ctx, ins);
+            // The wrapped matcher already failed, so there is nothing to validate.
+            if(not result)
+                return nullopt;
+            if(contains(ctx.instructions, "broadcast"))
+            {
+                // The reduce is `ins` itself when its operator is named fused_reduce;
+                // otherwise it is taken from the "reduce" binding.
+                instruction_ref reduce;
+                if(ins->get_operator().name() == "fused_reduce")
+                {
+                    reduce = ins;
+                }
+                else
+                {
+                    assert(contains(ctx.instructions, "reduce"));
+                    reduce = ctx.instructions["reduce"];
+                }
+                auto axes = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
+                auto broadcast = ctx.instructions["broadcast"];
+                if(not is_valid_broadcast(broadcast, axes))
+                    return nullopt;
+            }
+            return result;
+        });
+}
+
+// Matches an input produced by `op` and bound as `name`, either directly or through
+// match_broadcast(); that alternative is additionally filtered by match_broadcast_axes().
 static auto match_broadcastable_input(const std::string& op, const std::string& name)
 {
     auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
     auto match_op_input = any_input(match_op, match::used_once());
     auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
-    return match::any_of(match_op_input, broadcast_match_op_input);
+    return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
 }
 
 static void finalize_reduce_module(module_ref m)
diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp
index 55fad66e5ac..3da0a39b0a2 100644
--- a/test/fuse_reduce.cpp
+++ b/test/fuse_reduce.cpp
@@ -97,6 +97,44 @@ TEST_CASE(pointwise_reduce)
     EXPECT(p1 == p2);
 }
 
+TEST_CASE(pointwise_reduce_unfusable_broadcast) /* pointwise add -> multibroadcast from {2, 1, 3} to {2, 4, 3} -> reduce_sum over axis 2; the expected program keeps the pointwise add and the multibroadcast outside the reduce module */ +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto addb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), addb); + mm->add_return({rsum}); + } + run_pass(p1); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto addb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add); + auto rsum = + add_reduce(p2, + "main:reduce_sum0", + {addb}, + {2}, + [&](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); + }); + mm->add_return({rsum}); + } + EXPECT(p1 == p2); +} + TEST_CASE(scalar_multibroadcast) { // Matches the find_pointwise_reduce matcher, but input x has a (scalar) shape @@ -283,6 +321,43 @@ TEST_CASE(reduce_pointwise) EXPECT(p1 == p2); } +TEST_CASE(reduce_pointwise_unfusable_broadcast) /* reduce_sum over axis 2 -> multibroadcast to {2, 4, 3} -> pointwise add; the expected program wraps only the reduce_sum in a reduce module and keeps the multibroadcasts and the pointwise add outside it */ +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto yb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y); + auto add = add_pointwise(p1, "main:pointwise0", {rsumb, yb}, 
single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + }); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto yb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y); + auto add = add_pointwise(p2, "main:pointwise0", {rsumb, yb}, single_pointwise("add")); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + TEST_CASE(reduce_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -325,6 +400,57 @@ TEST_CASE(reduce_reduce) EXPECT(p1 == p2); } +TEST_CASE(reduce_reduce_unfusable_broadcast) /* the first reduce_sum is expected to stay in its own reduce module; the sub, the second reduce_sum and the sqrt are expected to end up together in one module that takes the two multibroadcasts as inputs */ +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto xb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x); + auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, xb}, single_pointwise("sub")); + auto rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsumdiff); + auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt")); + mm->add_return({sqrt}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + return 
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + }); + + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto xb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x); + + auto sqrt = add_reduce( + p2, + "main:pointwise0:main:reduce_sum1:main:pointwise1", + {rsumb, xb}, + {2}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsumdiff = add_pointwise( + p2, rm, "main:pointwise0", {inputs[0], inputs[1]}, single_pointwise("sub")); + auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + rsumdiff); + return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt")); + }); + mm->add_return({sqrt}); + } + EXPECT(p1 == p2); +} + TEST_CASE(parallel_reduce_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/include/layernorm.hpp b/test/include/layernorm.hpp index ed8cc008bd7..d800500345e 100644 --- a/test/include/layernorm.hpp +++ b/test/include/layernorm.hpp @@ -62,3 +62,43 @@ inline migraphx::instruction_ref add_layernorm(migraphx::module& m, m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); } + /* Appends a layernorm of x over axis 2 to m, built from reduce_mean, multibroadcast and elementwise instructions; adds "scale" and "bias" parameters sized by dims.back() and returns the final add instruction. NOTE(review): the element type of the std::vector parameter is not visible in this listing -- confirm against the real patch. */ +inline migraphx::instruction_ref add_pointwise_layernorm(migraphx::module& m, + migraphx::instruction_ref x, + const std::vector& dims, + float eps = 1e-12f) +{ + auto mgx_type = x->get_shape().type(); + auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {1, 1, dims.back()}}); + auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {1, 1, dims.back()}}); + + auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}}); + auto one = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {1}}); + + auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x); + auto mean_mbcast = 
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto x_minus_mean = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); + auto sqdiff = m.add_instruction(migraphx::make_op("sqdiff"), x, mean_mbcast); + auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), sqdiff); + + auto epsilon_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", var->get_shape().lens()}}), epsilon); + auto var_stable = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); /* eps is added to the variance before the rsqrt */ + auto inv_stddev_x = m.add_instruction(migraphx::make_op("rsqrt"), var_stable); + auto inv_stddev_x_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), inv_stddev_x); + auto norm = m.add_instruction(migraphx::make_op("mul"), x_minus_mean, inv_stddev_x_mbcast); + + auto one_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", scale->get_shape().lens()}}), one); + auto add_scale = m.add_instruction(migraphx::make_op("add"), scale, one_mbcast); /* the scale parameter is offset by one before it multiplies the normalized value */ + auto scale_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), add_scale); + auto scale_norm = m.add_instruction(migraphx::make_op("mul"), norm, scale_mbcast); + + auto bias_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); + + return m.add_instruction(migraphx::make_op("add"), scale_norm, bias_mbcast); +} diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index eef8acd5de7..8fc8389413d 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -185,3 +185,16 @@ struct test_add_layernorm_add_gemm_nonstd : verify_program +{ /* NOTE(review): the declaration of the struct added by this hunk, the hunk's leading context lines, and the element type of the std::vector below are not visible in this listing -- confirm against the real patch. The visible body builds a program with a float parameter "x" of lens {1, 9, 6} and applies add_pointwise_layernorm to it. */ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 9, 6}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + add_pointwise_layernorm(*mm, x, dims); + 
return p; + } +};