Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar multiplication and followed by conv #3478

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
44 changes: 44 additions & 0 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,55 @@
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

if(a_ins->get_shape().scalar()){
return;
}

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};

struct find_scalar_mul_conv
{
auto matcher() const
{
return match::name("mul")(
match::either_arg(0, 1)(
conv_const_weights().bind("conv"),
match::either_arg(0, 1)(
match::name("broadcast", "multibroadcast", "constant").bind("scalar"),
match::any().bind("scalar")

Check warning on line 462 in src/simplify_algebra.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Too many nested parentheses can affect readability; consider using variables instead. [migraphx-MatcherNestedParentheses]
)
)
);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto scalar_ins = r.instructions["scalar"];
auto w_ins = r.instructions["w"];

if(scalar_ins->get_shape().elements() != 1)
return;
const auto& w_shape = w_ins->get_shape().lens();

if(scalar_ins->get_shape().ndim() != w_shape.size())
{
scalar_ins = m.insert_instruction(ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_shape}}), scalar_ins);
}

auto new_weights = m.insert_instruction(ins, make_op("mul"), scalar_ins, w_ins);

auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);

m.replace_instruction(ins, new_conv);
}
};

struct find_dot_add
{
auto matcher() const
Expand Down Expand Up @@ -1981,6 +2024,7 @@
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_scalar_mul_conv{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
Expand Down
3 changes: 2 additions & 1 deletion src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ struct find_unary_shape_transforms
{
return ins->name() == "@literal" or
ins->get_operator().attributes().contains("pointwise") or
contains(ins->name(), "reduce");
contains(ins->name(), "reduce") or
ins->get_operator().attributes().contains("scalar"); //For scalar operation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What operators have a scalar attribute?

}

void apply(module& m, const match::matcher_result& mr) const
Expand Down
Loading