Skip to content

Commit

Permalink
fix bug in find_concat_op when input broadcasts are on different axes (
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Aug 20, 2024
1 parent 43d3bfc commit 873b335
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,18 @@ struct find_concat_op
op.attributes().contains("pointwise");
}

static bool is_valid_concat(std::vector<instruction_ref> ins, size_t axis)
{
auto concat_lens = ins.front()->get_shape().lens();
concat_lens.erase(concat_lens.begin() + axis);

return std::all_of(ins.begin(), ins.end(), [&](auto i) {
auto lens = i->get_shape().lens();
lens.erase(lens.begin() + axis);
return lens == concat_lens;
});
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
Expand Down Expand Up @@ -852,6 +864,8 @@ struct find_concat_op
std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i);
});
if(not is_valid_concat(inputs, iaxis))
return {start, last};
auto concat =
m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
Expand Down
21 changes: 21 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3999,4 +3999,25 @@ TEST_CASE(conv_concat_group)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(find_concat_different_broadcast_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {128, 1, 1, 1, 1}};
migraphx::shape s2{migraphx::shape::float_type, {1, 3, 1, 1, 1}};
migraphx::module m1;
{
auto l1 = m1.add_literal(migraphx::generate_literal(s1, 1));
auto l2 = m1.add_literal(migraphx::generate_literal(s2, 2));
auto bc1 = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {128, 3, 224, 224, 1}}}), l1);
auto bc2 = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {128, 3, 224, 224, 1}}}), l2);
auto cat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 4}}), bc1, bc2);
m1.add_return({cat});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 873b335

Please sign in to comment.