Skip to content

Commit

Permalink
Logic improved
Browse files Browse the repository at this point in the history
  • Loading branch information
aarushijai committed Oct 3, 2024
1 parent b0edd34 commit 69457bd
Showing 1 changed file with 57 additions and 36 deletions.
93 changes: 57 additions & 36 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,16 @@ struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
return match::name("mul")(
match::none_of[match::outputs()](match::name("convolution")),
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -440,55 +442,74 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

if(a_ins->get_shape().scalar()){
return;
}
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
/* Delete this later
struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(
match::none_of[match::outputs()](match::name("convolution")),
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
*/

struct find_scalar_mul_conv
{
auto matcher() const
{
return match::name("mul")(
match::either_arg(0, 1)(
conv_const_weights().bind("conv"),
match::either_arg(0, 1)(
match::name("broadcast", "multibroadcast", "constant").bind("scalar"),
match::any().bind("scalar")
)
)
);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto scalar_ins = r.instructions["scalar"];
auto w_ins = r.instructions["w"];

if(scalar_ins->get_shape().elements() != 1)
return;
const auto& w_shape = w_ins->get_shape().lens();
match::is_constant().bind("scalar"),
match::name("convolution").bind("conv")
));
}

if(scalar_ins->get_shape().ndim() != w_shape.size())
{
scalar_ins = m.insert_instruction(ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_shape}}), scalar_ins);
}
void apply(module& m, const match::matcher_result& r) const

Check warning on line 493 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L493

Added line #L493 was not covered by tests
{
auto ins = r.result;
auto scalar = r.instructions["scalar"];
auto conv_ins = r.instructions["conv"];

Check warning on line 497 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L495-L497

Added lines #L495 - L497 were not covered by tests

auto new_weights = m.insert_instruction(ins, make_op("mul"), scalar_ins, w_ins);
// Get the convol's input and weights
auto conv_input = conv_ins->inputs().front(); // input to conv
auto conv_weights = conv_ins->inputs().back(); // weights of the conv

Check warning on line 501 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L500-L501

Added lines #L500 - L501 were not covered by tests

auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
auto scaled_weights = m.insert_instruction(ins, make_op("mul"), scalar, conv_weights);

Check warning on line 503 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L503

Added line #L503 was not covered by tests

// new conv with modified weights
auto new_conv = m.insert_instruction(ins, conv_ins->get_operator(), conv_input, scaled_weights);

Check warning on line 506 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L506

Added line #L506 was not covered by tests

m.replace_instruction(ins, new_conv);
}

Check warning on line 509 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L508-L509

Added lines #L508 - L509 were not covered by tests
};


struct find_dot_add
{
auto matcher() const
Expand Down

0 comments on commit 69457bd

Please sign in to comment.